[CK_TILE] Refine pk_fp4's fill, pack, and unpack (#2845)

* fix bug

* let pack/unpack return pk_fp4_t

* fix clang-format
This commit is contained in:
Gino Lu
2025-09-17 10:54:06 +08:00
committed by GitHub
parent db79fad16f
commit c2997f2b7f
3 changed files with 70 additions and 36 deletions

View File

@@ -23,7 +23,8 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
struct pk_float4_e2m1_t;
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
// TODO: Add stochastic method
struct pk_float4_e2m1_t
@@ -31,7 +32,7 @@ struct pk_float4_e2m1_t
// TODO: Can we merge raw_type and type?
using raw_type = uint8_t;
using type = raw_type;
raw_type data;
type data;
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
@@ -39,12 +40,12 @@ struct pk_float4_e2m1_t
{
}
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
: data{float_to_e2m1(init, scale)}
: data{float_to_pk_fp4(init, scale)}
{
}
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr type& get() { return data; }
CK_TILE_HOST_DEVICE constexpr type get() const { return data; }
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
@@ -61,8 +62,19 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const
{
return _unpack(number<I>{});
}
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0,
const pk_float4_e2m1_t& x1)
{
return _pack(x0.get(), x1.get());
}
template <index_t I>
CK_TILE_HOST_DEVICE constexpr type _unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
@@ -92,7 +104,7 @@ struct pk_float4_e2m1_t
};
using pk_fp4_t = pk_float4_e2m1_t;
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
using pk_fp4_raw_t = typename pk_fp4_t::type;
template <>
struct numeric_traits<pk_fp4_t>
@@ -124,7 +136,7 @@ struct numeric<pk_fp4_t>
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; }
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
// N/A
@@ -136,7 +148,7 @@ struct numeric<pk_fp4_t>
};
template <index_t I>
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
@@ -202,7 +214,7 @@ CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16_t>(data, scale);
#else
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
#endif
}
@@ -211,13 +223,13 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16x2_t>(data, scale);
#else
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
// TODO: make it generic so that we can convert from directrly.
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -227,14 +239,20 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
{
return float_to_e2m1(x, scale);
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
auto res = convert_to_type<pk_fp4_t>(x, scale);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x), scale);
auto res = float_to_mxfp4(type_convert<float>(x), scale);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
@@ -242,7 +260,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return float_to_e2m1(type_convert<float>(x), scale);
auto res = float_to_mxfp4(type_convert<float>(x), scale);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
@@ -250,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
@@ -258,7 +277,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
@@ -266,7 +285,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
@@ -301,7 +320,7 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32_t>(data, scale);
#else
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
return convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
@@ -309,8 +328,8 @@ CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32x2_t>(data, scale);
#else
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
return fp32x2_t{convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale)};
#endif
}
@@ -319,7 +338,7 @@ CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16_t>(data, scale);
#else
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
@@ -327,28 +346,29 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16x2_t>(data, scale);
#else
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
#else
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
return e2m1_to_fp32_table[_unpack(number<0>{})] * scale;
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale};
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
return type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale;
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
return fp16x2_t{
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
scale)};
}
#endif

View File

@@ -67,7 +67,10 @@ struct FillUniformDistribution
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
@@ -77,8 +80,12 @@ struct FillUniformDistribution
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
std::generate(first, last, [&dis, &gen]() {
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
}
}

View File

@@ -2,6 +2,7 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
@@ -29,6 +30,12 @@ TEST(PackedFp4, NumericLimits)
EXPECT_EQ(ck_tile::numeric<pk_fp4_t>::epsilon(), pk_fp4_t{0b00010001});
EXPECT_EQ(ck_tile::numeric<pk_fp4_t>::round_error(), pk_fp4_t{0b00010001});
}
TEST(PackedFp4, fill)
{
std::vector<pk_fp4_t> v_fp4(4);
ck_tile::FillUniformDistribution<pk_fp4_t>{1.f, 1.f}(v_fp4);
EXPECT_EQ(v_fp4[0].get(), pk_fp4_t{0b00100010}.get());
}
TEST(PackedFp4, ConvertBasic)
{
EXPECT_EQ(ck_tile::convert_to_type<pk_fp4_t>(0.0f), pk_fp4_t{0b00000000}.get());
@@ -102,7 +109,7 @@ struct SrcPkfp4Dst
// ex: fp32_t -> fp4 -> bf16_t
dst[i] = toDST(toPF4(src[i]));
// ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t
dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{})));
dst[i + 1] = toDST(toPF4(input2).unpack(number<1>{}));
}
else
{