mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Use inline-assembly based v_pk_mul_f32 to scale tile pcomp_tile in non-softmax pipeline on gfx950
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <ck_tile/ops/fmha/block/block_dropout.hpp>
|
||||
|
||||
#include "hstu_attention_fwd_pipeline_policy.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -393,8 +394,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
detail::scale_tile_in_pack(pcomp_tile, scale_p);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
|
||||
@@ -53,6 +53,85 @@ struct has_use_trload_flag<
|
||||
template <typename T>
|
||||
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
|
||||
|
||||
// scale is uniform (scalar register), c is per-lane (vector register)
|
||||
// GFX9 VOP2: V_MUL_F32 VDST, SRC0, SRC1 - SRC0 can be SGPR, SRC1 must be VGPR
|
||||
|
||||
// scale is uniform (scalar register), c is per-lane (vector register)
|
||||
// GFX9 VOP2: V_MUL_F32 VDST, SRC0, SRC1 - SRC0 can be SGPR, SRC1 must be VGPR
|
||||
CK_TILE_DEVICE static void v_mul_f32_two(float& c0, float& c1, float scale)
|
||||
{
|
||||
asm volatile("v_mul_f32 %[v_c0], %[s_scale], %[v_c0] \n\
|
||||
v_mul_f32 %[v_c1], %[s_scale], %[v_c1]"
|
||||
: [v_c0] "+v"(c0), [v_c1] "+v"(c1)
|
||||
: [s_scale] "s"(scale)
|
||||
:);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static void v_mul_f32(float& c, float scale)
|
||||
{
|
||||
asm volatile("v_mul_f32 %[v_c], %[s_scale], %[v_c]" : [v_c] "+v"(c) : [s_scale] "s"(scale) :);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
|
||||
{
|
||||
fp32x2_t result;
|
||||
asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename InOutDstrTensor>
|
||||
CK_TILE_DEVICE static void scale_tile_in_scalar(InOutDstrTensor& in_out_dstr_tensor, float scale)
|
||||
{
|
||||
using DataType = typename InOutDstrTensor::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
|
||||
{
|
||||
auto tmp_scale = type_convert<DataType>(scale);
|
||||
|
||||
constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
v_mul_f32_two(in_out_dstr_tensor.thread_buf_[idx],
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1],
|
||||
tmp_scale);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InOutDstrTensor>
|
||||
CK_TILE_DEVICE static void scale_tile_in_pack(InOutDstrTensor& in_out_dstr_tensor, float scale)
|
||||
{
|
||||
using DataType = typename InOutDstrTensor::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<std::remove_cv_t<DataType>, float>)
|
||||
{
|
||||
fp32x2_t pk_scale;
|
||||
|
||||
pk_scale.x = scale;
|
||||
pk_scale.y = scale;
|
||||
|
||||
constexpr index_t thread_buffer_size = InOutDstrTensor::get_thread_buffer_size();
|
||||
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
fp32x2_t input = {in_out_dstr_tensor.thread_buf_[idx],
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1]};
|
||||
auto output = pk_mul_f32(input, pk_scale);
|
||||
in_out_dstr_tensor.thread_buf_[idx] = output.x;
|
||||
in_out_dstr_tensor.thread_buf_[idx + 1] = output.y;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, in_out_dstr_tensor);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user