From 888b6cad866698df47d6e0d9b252486ccd44e49b Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Thu, 30 Apr 2026 14:06:06 +0000
Subject: [PATCH] Use inline-assembly based v_pk_mul_f32 to scale tile pcomp_tile
 in non-softmax pipeline on gfx950

---
  Review notes (commentary only; this text sits in the free-form area of the
  patch and is not part of the change itself).

  What the patch does
  * hstu_attention_no_softmax_fwd_trload_pipeline.hpp: includes
    "hstu_attention_util.hpp" and replaces the tile_elementwise_inout lambda
    that multiplied every element of pcomp_tile by scale_p with a single call
    to detail::scale_tile_in_pack(pcomp_tile, scale_p).
  * hstu_attention_util.hpp: adds five helpers just before the closing
    "namespace detail" line:
      - v_mul_f32_two(c0, c1, scale): two v_mul_f32 instructions in one asm
        statement; c0 and c1 are updated in place ("+v" operands), scale is
        bound with the "s" constraint.
      - v_mul_f32(c, scale): the single-value form of the same instruction.
      - pk_mul_f32(lhs, rhs): one v_pk_mul_f32 on two fp32x2_t operands, both
        bound with the "v" constraint; returns the fp32x2_t result.
      - scale_tile_in_scalar(tensor, scale): when the DataType check against
        float holds, walks thread_buf_ in steps of two with static_for and
        scales each pair through v_mul_f32_two; otherwise falls back to
        tile_elementwise_inout with x = x * scale.
      - scale_tile_in_pack(tensor, scale): same shape, but copies scale into
        both halves of an fp32x2_t, multiplies each pair of thread_buf_
        elements with pk_mul_f32 and writes .x / .y back to idx / idx + 1.

  Points to confirm before merging
  * NOTE(review): both static_for<0, thread_buffer_size, 2> loops access
    thread_buf_[idx + 1]. Nothing visible in this patch guarantees that
    thread_buffer_size is even; an odd size would index one past the end.
    Confirm for the tiles used here, or add a static_assert.
  * NOTE(review): the "s" constraint on scale in v_mul_f32_two / v_mul_f32
    matches the "scale is uniform" comment; presumably callers must pass a
    wave-uniform value - verify against the call sites.
  * NOTE(review): the two-line "scale is uniform ... / GFX9 VOP2 ..." comment
    is added twice in a row at the top of the new code.
  * NOTE(review): within this patch only scale_tile_in_pack has a call site;
    v_mul_f32 and scale_tile_in_scalar are added but not called here.
  * NOTE(review): the fallback branch multiplies by the float scale directly,
    while the removed pipeline lambda went through type_convert(...) first.
    Confirm the result is the same for non-float tile types.
  * NOTE(review): the subject line targets gfx950, but no architecture guard
    around v_pk_mul_f32 is visible in this diff - confirm the pipeline is
    only built for targets that provide the instruction.
  * NOTE(review): "};" after the else blocks and after both function-template
    bodies leaves stray semicolons behind.
  * NOTE(review): in this copy of the patch several angle-bracket spans look
    like they were lost (bare "template", "#include" with no header name,
    "type_convert(...)", "std::is_same_v, float>", "has_use_trload_flag::value",
    and a From: header without an address). The original line breaks and
    indentation were also lost; the layout below is a reconstruction.
    Compare against the original commit before applying.

 ...tention_no_softmax_fwd_trload_pipeline.hpp |  4 +-
 .../18_hstu_attention/hstu_attention_util.hpp | 79 +++++++++++++++++++
 2 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp
index f65cd614a7..14e7897c88 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp
@@ -7,6 +7,7 @@
 #include
 
 #include "hstu_attention_fwd_pipeline_policy.hpp"
+#include "hstu_attention_util.hpp"
 
 namespace ck_tile {
 
@@ -393,8 +394,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
 
             tile_elementwise_inout(f_silu, pcomp_tile);
 
-            tile_elementwise_inout([&](auto& x) { x = x * type_convert(scale_p); },
-                                   pcomp_tile);
+            detail::scale_tile_in_pack(pcomp_tile, scale_p);
 
             seqlen_k_curr += kN0;
 
diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp
index b098f48df0..6ad710b6a9 100644
--- a/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp
+++ b/example/ck_tile/18_hstu_attention/hstu_attention_util.hpp
@@ -53,6 +53,85 @@ struct has_use_trload_flag<
 template
 static inline constexpr bool is_using_trload_v = has_use_trload_flag::value;
 
+// scale is uniform (scalar register), c is per-lane (vector register)
+// GFX9 VOP2: V_MUL_F32 VDST, SRC0, SRC1 - SRC0 can be SGPR, SRC1 must be VGPR
+
+// scale is uniform (scalar register), c is per-lane (vector register)
+// GFX9 VOP2: V_MUL_F32 VDST, SRC0, SRC1 - SRC0 can be SGPR, SRC1 must be VGPR
+CK_TILE_DEVICE static void v_mul_f32_two(float& c0, float& c1, float scale)
+{
+    asm volatile("v_mul_f32 %[v_c0], %[s_scale], %[v_c0] \n\
+                  v_mul_f32 %[v_c1], %[s_scale], %[v_c1]"
+                 : [v_c0] "+v"(c0), [v_c1] "+v"(c1)
+                 : [s_scale] "s"(scale)
+                 :);
+}
+
+CK_TILE_DEVICE static void v_mul_f32(float& c, float scale)
+{
+    asm volatile("v_mul_f32 %[v_c], %[s_scale], %[v_c]" : [v_c] "+v"(c) : [s_scale] "s"(scale) :);
+}
+
+CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
+{
+    fp32x2_t result;
+    asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
+                 : [result] "=v"(result)
+                 : [lhs] "v"(lhs), [rhs] "v"(rhs));
+    return result;
+}
+
+template
+CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_tensor, float scale)
+{
+    using DataType = typename InOutDstrTensor::DataType;
+
+    if constexpr(std::is_same_v, float>)
+    {
+        auto tmp_scale = type_convert(scale);
+
+        constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
+
+        static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
+            v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx],
+                          in_out_dstr_tensor.thread_buf_[idx + 1],
+                          tmp_scale);
+        });
+    }
+    else
+    {
+        tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
+    };
+};
+
+template
+CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tensor, float scale)
+{
+    using DataType = typename InOutDstrTensor::DataType;
+
+    if constexpr(std::is_same_v, float>)
+    {
+        fp32x2_t pk_scale;
+
+        pk_scale.x = scale;
+        pk_scale.y = scale;
+
+        constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
+
+        static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
+            fp32x2_t input = {in_out_dstr_tensor.thread_buf_[idx],
+                              in_out_dstr_tensor.thread_buf_[idx + 1]};
+            auto output = pk_mul_f32(input, pk_scale);
+            in_out_dstr_tensor.thread_buf_[idx] = output.x;
+            in_out_dstr_tensor.thread_buf_[idx + 1] = output.y;
+        });
+    }
+    else
+    {
+        tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
+    };
+};
+
 } // namespace detail
 
 } // namespace ck_tile