let pack/unpack return pk_fp4_t

This commit is contained in:
Gino Lu
2025-09-16 02:29:01 -05:00
committed by mtgu0705
parent a333206929
commit ec9bcef591
2 changed files with 31 additions and 24 deletions

View File

@@ -62,8 +62,15 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
template <index_t I>
CK_TILE_HOST_DEVICE constexpr type unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static type pack(const type x0, const type x1)
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const { return _unpack(number<I>{});}
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0, const pk_float4_e2m1_t& x1)
{
return _pack(x0.get(), x1.get());
}
template <index_t I>
CK_TILE_HOST_DEVICE constexpr type _unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
@@ -137,7 +144,7 @@ struct numeric<pk_fp4_t>
};
template <index_t I>
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
@@ -203,7 +210,7 @@ CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16_t>(data, scale);
#else
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
#endif
}
@@ -212,8 +219,8 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<bf16x2_t>(data, scale);
#else
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
@@ -232,7 +239,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float sca
return impl::_to_f4(x, scale);
#else
auto res = convert_to_type<pk_fp4_t>(x, scale);
return pk_fp4_t::pack(res, res);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
@@ -241,7 +248,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float sca
return impl::_to_f4(x, scale);
#else
auto res = float_to_mxfp4(type_convert<float>(x), scale);
return pk_fp4_t::pack(res, res);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
@@ -250,7 +257,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
return impl::_to_f4(x, scale);
#else
auto res = float_to_mxfp4(type_convert<float>(x), scale);
return pk_fp4_t::pack(res, res);
return pk_fp4_t::_pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
@@ -258,7 +265,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
@@ -266,7 +273,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
@@ -274,7 +281,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
@@ -309,7 +316,7 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32_t>(data, scale);
#else
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
return convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
@@ -317,8 +324,8 @@ CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp32x2_t>(data, scale);
#else
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
return fp32x2_t{convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale),
convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale)};
#endif
}
@@ -327,7 +334,7 @@ CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16_t>(data, scale);
#else
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
#endif
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
@@ -335,28 +342,28 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
#if CK_TILE_FP4_CVT_DEVICE
return impl::_from_f4<fp16x2_t>(data, scale);
#else
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
#endif
}
#else
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
{
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
return e2m1_to_fp32_table[_unpack(number<0>{})] * scale;
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale};
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
{
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
return type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale;
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
{
return fp16x2_t{
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale),
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) * scale)};
}
#endif

View File

@@ -109,7 +109,7 @@ struct SrcPkfp4Dst
// ex: fp32_t -> fp4 -> bf16_t
dst[i] = toDST(toPF4(src[i]));
// ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t
dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{})));
dst[i + 1] = toDST(toPF4(input2).unpack(number<1>{}));
}
else
{