From 629573e3e3fa38dd630d0a6f547503ede30c3daf Mon Sep 17 00:00:00 2001 From: ApoorvaKalyani Date: Fri, 30 Jan 2026 18:34:54 +0100 Subject: [PATCH] Test fix for gemm_b_scale_xdl_v3. (#3674) [ROCm/composable_kernel commit: 70d71b1514cc650ef7808d8757097f2d8617d313] --- .../element/unary_element_wise_operation.hpp | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index e324479420..13d421c80c 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -52,9 +52,9 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) const int EX = 0x64006400; // Extract the two int4 at low bit and create two fp16 number. - int lo = amd_assembly_and_or_b32(q, LO, EX); + int lo = (q & LO) | EX; // Extract the two int4 at hight bit and create two fp16 number. - int hi = amd_assembly_and_or_b32(q, HI, EX); + int hi = (q & HI) | EX; const int SUB = 0xE408E408; // half2 {-1032, -1032} const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} @@ -62,19 +62,15 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) vector_type res; + res.template AsType()(Number<0>{}) = bit_cast(lo) + bit_cast(SUB); + + res.template AsType()(Number<1>{}) = + bit_cast(hi) * bit_cast(MUL) + bit_cast(ADD); + res.template AsType()(Number<0>{}) = - amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); - - res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( - bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); - - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(res.template AsType()(Number<0>{})) - : "v"(res.template AsType()(Number<0>{})), "v"(scale)); - - asm volatile("v_pk_mul_f16 %0, %1, %2" - : "=v"(res.template AsType()(Number<1>{})) - : "v"(res.template AsType()(Number<1>{})), "v"(scale)); + res.template AsType()(Number<0>{}) * scale; + res.template AsType()(Number<1>{}) = + res.template AsType()(Number<1>{}) * scale; return res.template AsType()[Number<0>{}]; }