diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 3e9df1806f..d72647395a 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -109,11 +109,14 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) float x_l = ((x_u8 & 0x0f) >> 0) - 8.f; float x_h = ((x_u8 & 0xf0) >> 4) - 8.f; + /* #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE fp32x2_t res = {x_h, x_l}; #elif fp32x2_t res = {x_l, x_h}; #endif + */ + fp32x2_t res = {x_l, x_h}; return res; } @@ -159,8 +162,8 @@ CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x, float scale) { // TODO(yadai): confirm quanzation algorithm - // auto float_vec2 = pk_int4_t_to_fp32x2_t(x); - auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + // auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return fp16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)}; @@ -183,7 +186,7 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x, float scale) { - auto float_vec2 = pk_int4_t_to_fp32x2_t_signed_conversion(x); + auto float_vec2 = pk_int4_t_to_fp32x2_t(x); float_vec2.x = float_vec2.x * scale; float_vec2.y = float_vec2.y * scale; return bf16x2_t{type_convert(float_vec2.x), type_convert(float_vec2.y)};