mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198)
* Write soft-sign in inline asm * Change tile idx computation * Add macro to turn off soft-sign asm opt * Use simple for loop to avoid register spill * Only do block id transform for masking cases
This commit is contained in:
@@ -15,7 +15,36 @@
|
||||
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
|
||||
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
namespace internal {
|
||||
__device__ inline float
|
||||
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
|
||||
{
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
|
||||
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
|
||||
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
|
||||
/// NOTICE: Make sure softmax_scale is stored in SGPR
|
||||
float result, numerator, denominator;
|
||||
asm volatile(
|
||||
"v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
|
||||
"v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
|
||||
"v_rcp_f32_e32 %[denominator], %[denominator]\n"
|
||||
"v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
|
||||
"v_mul_f32_e32 %[result], %[numerator], %[denominator]"
|
||||
: [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
|
||||
: [softmax_scale] "s"(softmax_scale),
|
||||
[logits] "v"(logits),
|
||||
[logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
|
||||
return result;
|
||||
#else
|
||||
return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
|
||||
#endif
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
template <typename ImplMask>
|
||||
struct StandardAttentionParams
|
||||
@@ -169,8 +198,8 @@ struct LogitsSoftCap
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return params.sm_scale * type_convert<float>(logits) *
|
||||
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -239,9 +268,8 @@ struct ComposedAttention
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return params.sm_scale * type_convert<float>(logits) *
|
||||
rcp<float>(1.f +
|
||||
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user