Use inline-assembly-based v_pk_mul_f32 to scale pcomp_tile in the non-softmax pipeline on gfx950

This commit is contained in:
Qianfeng Zhang
2026-04-30 14:06:06 +00:00
parent 4c583f0574
commit 888b6cad86
2 changed files with 81 additions and 2 deletions

View File

@@ -7,6 +7,7 @@
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
#include "hstu_attention_fwd_pipeline_policy.hpp"
#include "hstu_attention_util.hpp"
namespace ck_tile {
@@ -393,8 +394,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
tile_elementwise_inout(f_silu, pcomp_tile);
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
pcomp_tile);
detail::scale_tile_in_pack(pcomp_tile, scale_p);
seqlen_k_curr += kN0;

View File

@@ -53,6 +53,85 @@ struct has_use_trload_flag<
// Shorthand for has_use_trload_flag<T>::value, so callers can write
// is_using_trload_v<T> in constant expressions instead of spelling out the trait.
template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
// Multiplies two per-lane floats in place by one shared scale factor, using two
// v_mul_f32 instructions emitted through a single inline-assembly statement.
//
// scale is uniform (scalar register), c is per-lane (vector register)
// GFX9 VOP2: V_MUL_F32 VDST, SRC0, SRC1 - SRC0 can be SGPR, SRC1 must be VGPR
//
// c0, c1: read-modify-write operands ("+v" constraint); each is replaced by
//         scale * itself.
// scale:  input operand bound with the "s" constraint and used as SRC0 of both
//         instructions.
CK_TILE_DEVICE static void v_mul_f32_two(float& c0, float& c1, float scale)
{
// The template string below is continued across two source lines with a trailing
// backslash; both instructions belong to the same asm statement.
asm volatile("v_mul_f32 %[v_c0], %[s_scale], %[v_c0] \n\
v_mul_f32 %[v_c1], %[s_scale], %[v_c1]"
: [v_c0] "+v"(c0), [v_c1] "+v"(c1)
: [s_scale] "s"(scale)
:);
}
// Multiplies one per-lane float in place by a scale factor with a single
// v_mul_f32 instruction emitted through inline assembly.
//
// c:     read-modify-write operand ("+v" constraint); replaced by scale * c.
// scale: input operand bound with the "s" constraint and used as SRC0.
CK_TILE_DEVICE static void v_mul_f32(float& c, float scale)
{
asm volatile("v_mul_f32 %[v_c], %[s_scale], %[v_c]" : [v_c] "+v"(c) : [s_scale] "s"(scale) :);
}
// Issues one v_pk_mul_f32 instruction through inline assembly and returns its
// destination operand.
//
// lhs, rhs: input operands, both bound with the "v" constraint.
// Returns:  the write-only "=v" output of the instruction; presumably the
//           component-wise product of lhs and rhs - confirm against the ISA
//           documentation for the target.
//
// NOTE(review): no architecture guard is visible in this function. The commit
// message mentions gfx950; confirm that every caller is only compiled for
// targets that implement v_pk_mul_f32.
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
// Left uninitialised on purpose: the asm statement below is its only writer.
fp32x2_t result;
asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
: [result] "=v"(result)
: [lhs] "v"(lhs), [rhs] "v"(rhs));
return result;
}
// Scales every element of a distributed tensor's per-thread buffer by `scale`,
// in place.
//
// When the element type is float, the multiplies are issued through the
// inline-assembly helpers v_mul_f32_two / v_mul_f32, which bind `scale` with the
// "s" constraint. Any other element type falls back to tile_elementwise_inout
// with a plain multiply.
//
// in_out_dstr_tensor: tensor whose thread_buf_ is modified in place.
// scale:              factor applied to every element.
//                     NOTE(review): the "s" constraint in the helpers presumes the
//                     value is uniform across the wave - confirm at call sites.
template <typename InOutDstrTensor>
CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_tensor, float scale)
{
    using DataType = typename InOutDstrTensor::DataType;

    if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
    {
        constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
        // Largest even prefix of the buffer; it is consumed two elements per step.
        constexpr index_t paired_size = thread_buffer_size - thread_buffer_size % 2;

        // DataType is float in this branch, so `scale` is passed through unchanged.
        static_for<0, paired_size, 2>{}([&](auto idx) {
            v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx],
                          in_out_dstr_tensor.thread_buf_[idx + 1],
                          scale);
        });

        // An odd-sized buffer leaves one trailing element without a partner; scale it
        // on its own instead of letting the pairwise step reach past the end.
        if constexpr(paired_size != thread_buffer_size)
        {
            static_for<paired_size, thread_buffer_size, 1>{}(
                [&](auto idx) { v_mul_f32(in_out_dstr_tensor.thread_buf_[idx], scale); });
        }
    }
    else
    {
        tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
    }
}
// Scales every element of a distributed tensor's per-thread buffer by `scale`,
// in place, two elements at a time.
//
// When the element type is float, neighbouring elements are packed into an
// fp32x2_t and multiplied by a packed copy of `scale` through pk_mul_f32 (one
// v_pk_mul_f32 instruction per pair). Any other element type falls back to
// tile_elementwise_inout with a plain multiply.
//
// in_out_dstr_tensor: tensor whose thread_buf_ is modified in place.
// scale:              factor applied to every element.
template <typename InOutDstrTensor>
CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tensor, float scale)
{
    using DataType = typename InOutDstrTensor::DataType;

    if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
    {
        // Same factor in both components so each half of a pair is scaled equally.
        const fp32x2_t pk_scale = {scale, scale};

        constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
        // Largest even prefix of the buffer; it is consumed two elements per step.
        constexpr index_t paired_size = thread_buffer_size - thread_buffer_size % 2;

        static_for<0, paired_size, 2>{}([&](auto idx) {
            const fp32x2_t input = {in_out_dstr_tensor.thread_buf_[idx],
                                    in_out_dstr_tensor.thread_buf_[idx + 1]};
            const fp32x2_t output = pk_mul_f32(input, pk_scale);

            in_out_dstr_tensor.thread_buf_[idx]     = output.x;
            in_out_dstr_tensor.thread_buf_[idx + 1] = output.y;
        });

        // An odd-sized buffer leaves one trailing element that cannot be packed with
        // a neighbour; scale it with an ordinary multiply.
        if constexpr(paired_size != thread_buffer_size)
        {
            static_for<paired_size, thread_buffer_size, 1>{}([&](auto idx) {
                in_out_dstr_tensor.thread_buf_[idx] = in_out_dstr_tensor.thread_buf_[idx] * scale;
            });
        }
    }
    else
    {
        tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
    }
}
} // namespace detail
} // namespace ck_tile