From 3ca1b9df66baae32f0c841be5c9ae4fcf47c88ae Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 16 May 2025 15:14:46 +0800 Subject: [PATCH] [CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198) * Write soft-sign in inline asm * Change tile idx computation * Add macro to turn off soft-sign asm opt * Use simple for loop to avoid register spill * Only do block id transform for masking cases [ROCm/composable_kernel commit: 791802b381c99e47966cbf4a987b91ab3d56bcfc] --- include/ck_tile/ops/fmha/block/variants.hpp | 38 ++++++++++++++++--- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 21 ++++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 13 ++++++- 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index 90fc5656fc..d8b0cdbb86 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -15,7 +15,36 @@ #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH #endif +#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM +#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 +#endif + namespace ck_tile { +namespace internal { +__device__ inline float +exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) +{ +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + /// NOTICE: Make sure softmax_scale is stored in SGPR + float result, numerator, denominator; + asm volatile( + "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" + "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" + "v_rcp_f32_e32 %[denominator], %[denominator]\n" + "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" + "v_mul_f32_e32 %[result], %[numerator], %[denominator]" + : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) + : [softmax_scale] "s"(softmax_scale), + [logits] "v"(logits), + [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); + return result; +#else + return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); +#endif +} +} // namespace internal template struct StandardAttentionParams @@ -169,8 +198,8 @@ struct LogitsSoftCap return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else @@ -239,9 +268,8 @@ struct ComposedAttention return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + - abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index ba327ee511..7472c82114 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ 
b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel }; const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index e07cf1c94e..8691622bb0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,8 +6,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -498,6 +499,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #else for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) { +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif apply_logits_transform(s_acc.thread_buf_[i]); } #endif
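For reference, the inline-asm path and the C++ fallback in variants.hpp compute the same soft-sign soft-capping, sm_scale * logits / (1 + |logits * logits_soft_cap_rcp|), which saturates toward +/- sm_scale * logits_soft_cap as |logits| grows. Below is a minimal host-side sketch of that math (softsign_soft_cap is a hypothetical helper name; plain division stands in for the device's v_rcp_f32 reciprocal approximation):

    // Host-side reference for the soft-sign logits soft-cap in this patch.
    #include <cmath>
    #include <cstdio>

    float softsign_soft_cap(float sm_scale, float logits, float logits_soft_cap_rcp)
    {
        // Same expression as the C++ fallback in variants.hpp:
        //   sm_scale * logits * rcp(1 + |logits * logits_soft_cap_rcp|)
        return sm_scale * logits / (1.0f + std::fabs(logits * logits_soft_cap_rcp));
    }

    int main()
    {
        const float sm_scale = 0.125f; // e.g. 1/sqrt(head_dim) for head_dim = 64
        const float cap      = 30.0f;  // logits_soft_cap
        const float xs[]     = {-1000.0f, -1.0f, 0.0f, 1.0f, 1000.0f};
        for(float logits : xs)
        {
            // Output saturates toward +/- sm_scale * cap for large |logits|.
            std::printf("%9.1f -> %f\n", logits, softsign_soft_cap(sm_scale, logits, 1.0f / cap));
        }
        return 0;
    }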
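The kernel change reverses the query-tile index only when a mask is present (kHasMask), under the in-code assumption that num_tile_n1 is always 1, so gridDim.z (or gridDim.x in the alternate grid layout) equals the number of m-tiles. With a causal-style mask, tiles with larger i_tile_m attend over longer key ranges and carry more work, so mapping the earliest-scheduled blocks onto the heaviest tiles plausibly evens out occupancy; the commit message itself only states that the transform is restricted to masking cases. A standalone sketch of the mapping (map_tile_m is a hypothetical helper, not a CK Tile API):

    #include <cstdio>

    // Mirrors the patch: reverse the m-tile order only for masked kernels.
    int map_tile_m(int i_tile_m, int num_tile_m, bool has_mask)
    {
        return has_mask ? (num_tile_m - 1 - i_tile_m) : i_tile_m;
    }

    int main()
    {
        const int num_tile_m = 4; // plays the role of gridDim.z (or gridDim.x)
        for(int i = 0; i < num_tile_m; ++i)
            std::printf("block %d -> tile_m %d (masked) / %d (unmasked)\n",
                        i,
                        map_tile_m(i, num_tile_m, true),
                        map_tile_m(i, num_tile_m, false));
        return 0;
    }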
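The pipeline change splits the logits-transform loop with __builtin_amdgcn_sched_barrier(0) when the soft-sign asm path is enabled. That builtin is clang/LLVM's AMDGPU scheduling barrier; a mask of 0 forbids the scheduler from moving any instruction across the call site, which per the patch comment keeps v_mfma producers from landing back-to-back with the inline-asm consumers, a pairing for which the compiler does not insert the required s_nop. A hypothetical compile-only HIP fragment showing the same placement pattern (transform_accumulator and the plain arithmetic are stand-ins, not the CK Tile pipeline):

    #include <hip/hip_runtime.h>

    // Stand-in for the pipeline's loop applying apply_logits_transform to s_acc.
    __global__ void transform_accumulator(float* acc, int n)
    {
        for(int i = 0; i < n; ++i)
        {
    #if defined(__HIP_DEVICE_COMPILE__)
            // Mask 0: nothing may be scheduled across this point, pinning the
            // boundary between earlier producers and later consumers.
            if(i == n / 2)
                __builtin_amdgcn_sched_barrier(0);
    #endif
            acc[i] = acc[i] * 2.0f - 1.0f; // stand-in transform
        }
    }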