Test fix for gemm_b_scale_xdl_v3. (#3674)

This commit is contained in:
ApoorvaKalyani
2026-01-30 18:34:54 +01:00
committed by GitHub
parent 63df1c0af2
commit 70d71b1514

View File

@@ -52,9 +52,9 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
const int EX = 0x64006400;
// Extract the two int4 values at the low bits and create two fp16 numbers.
int lo = amd_assembly_and_or_b32(q, LO, EX);
int lo = (q & LO) | EX;
// Extract the two int4 values at the high bits and create two fp16 numbers.
int hi = amd_assembly_and_or_b32(q, HI, EX);
int hi = (q & HI) | EX;
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
@@ -62,19 +62,15 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) = bit_cast<half2_t>(lo) + bit_cast<half2_t>(SUB);
res.template AsType<half2_t>()(Number<1>{}) =
bit_cast<half2_t>(hi) * bit_cast<half2_t>(MUL) + bit_cast<half2_t>(ADD);
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<0>{}))
: "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<1>{}))
: "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
res.template AsType<half2_t>()(Number<0>{}) * scale;
res.template AsType<half2_t>()(Number<1>{}) =
res.template AsType<half2_t>()(Number<1>{}) * scale;
return res.template AsType<half4_t>()[Number<0>{}];
}