[CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198)

* Write soft-sign in inline asm

* Change tile idx computation

* Add macro to turn off soft-sign asm opt

* Use simple for loop to avoid register spill

* Only do block id transform for masking cases

[ROCm/composable_kernel commit: 791802b381]
This commit is contained in:
Po Yen Chen
2025-05-16 15:14:46 +08:00
committed by GitHub
parent b54eceb59f
commit 3ca1b9df66
3 changed files with 63 additions and 9 deletions

View File

@@ -15,7 +15,36 @@
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
#endif
#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
#endif
namespace ck_tile {
namespace internal {
// Soft-sign logits soft-capping: computes
//   softmax_scale * logits / (1 + |logits * logits_soft_cap_rcp|)
// i.e. a soft-sign-shaped squashing of the attention logits, pre-scaled by the
// softmax scale. On gfx90a/gfx94 with the SOFTSIGN cap variant and the asm
// opt-in enabled, the expression is emitted as hand-written VALU instructions;
// otherwise the plain C++ fallback below is used.
//
// Parameters:
//   softmax_scale       - attention softmax scaling factor. NOTE(review): the
//                         asm path binds this with an "s" (scalar-register)
//                         constraint, so the caller must ensure it is
//                         wave-uniform / SGPR-resident (see NOTICE below).
//   logits              - raw attention logit (per-lane value).
//   logits_soft_cap_rcp - reciprocal of the soft-cap threshold.
//
// Returns: the soft-capped, scale-multiplied logit as float.
//
// NOTE(review): despite the "exp2" in the name, no exponential is computed
// here — presumably the result feeds a later exp2-based softmax; confirm
// against callers before renaming.
__device__ inline float
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
    (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
     CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
    /// NOTICE: Make sure softmax_scale is stored in SGPR
    // denominator = 1 + |logits * logits_soft_cap_rcp|, then approximated
    // reciprocal via v_rcp_f32 (hardware estimate, not IEEE-exact division —
    // this is an intentional precision/speed trade vs. the C++ fallback).
    // numerator = softmax_scale * logits; result = numerator * rcp(denominator).
    // "=&v" early-clobbers keep numerator/denominator out of the input regs.
    float result, numerator, denominator;
    asm volatile(
        "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
        "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
        "v_rcp_f32_e32 %[denominator], %[denominator]\n"
        "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
        "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
        : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
        : [softmax_scale] "s"(softmax_scale),
          [logits] "v"(logits),
          [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
    return result;
#else
    // Portable fallback: same formula via ck_tile's rcp<float> helper.
    return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
#endif
}
} // namespace internal
template <typename ImplMask>
struct StandardAttentionParams
@@ -169,8 +198,8 @@ struct LogitsSoftCap
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return params.sm_scale * type_convert<float>(logits) *
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
return internal::exp2_soft_sign_impl(
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
}
else
@@ -239,9 +268,8 @@ struct ComposedAttention
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return params.sm_scale * type_convert<float>(logits) *
rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
return internal::exp2_soft_sign_impl(
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
}
else

View File

@@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
else
{
@@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
}

View File

@@ -6,8 +6,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -498,6 +499,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
// Avoid data hazard if v_mfma is followed by inline asm consumer
// instructions. In this case, compiler won't add s_nop for us
if(i == s_acc.thread_buf_.size() / 2)
{
__builtin_amdgcn_sched_barrier(0);
}
#endif
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif