// Patch against include/ck_tile/core/numeric/pk_int4.hpp, two hunks. The patch text that
// follows this preamble is unchanged.
//
// Hunk 1 (old line 127; the hunk header context names
// pk_int4_t_to_fp32x2_t_signed_conversion):
//   - Context: x_l and x_h values greater than 7 are remapped to (value - 16).
//   - Change: the preprocessor selection on CK_TILE_USE_PK4_LAYOUT_SHUFFLE that chose between
//     res = {x_h, x_l} and res = {x_l, x_h} is wrapped in a block comment, and an
//     unconditional `fp32x2_t res = {x_l, x_h};` is added. After the patch the macro no
//     longer affects this function and the {x_h, x_l} ordering is never produced.
//
// Hunk 2 (old line 180; pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale)):
//   - Change: float_vec2 is now obtained from pk_int4_t_to_fp32x2_t_signed_conversion(x)
//     instead of pk_int4_t_to_fp32x2_t(x); both components are still multiplied by `scale`
//     and then passed through type_convert into the returned bf16x2_t.
//
// NOTE(review): the disabled block uses a bare `#elif` with no condition; `#else` looks like
//   what was intended, and that may be why the block was commented out -- TODO confirm.
// NOTE(review): prefer deleting the dead preprocessor block (or fixing it) over leaving it
//   commented out, so the role of CK_TILE_USE_PK4_LAYOUT_SHUFFLE here stays unambiguous.
// NOTE(review): `type_convert(...)` appears without a template argument in this text; it may
//   have been lost when the patch was extracted -- verify against the real header.
diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index f0a681cb17..3e9df1806f 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -127,11 +127,14 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in x_l = x_l > 7 ? x_l - 16 : x_l; x_h = x_h > 7 ? x_h - 16 : x_h; + /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif + */ + fp32x2_t res = {x_l, x_h}; return res; } @@ -180,7 +183,7 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) { - auto float_vec2 = pk_int4_t_to_fp32x2_t(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)};