From 2c297ba2011823d89735aaa600afba7599e35e27 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 1 Sep 2025 15:16:28 -0500 Subject: [PATCH] apply scale p --- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 59 ++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index b3ec335ee6..a7871502ad 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1478,20 +1478,51 @@ struct FmhaFwdKernel } else { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + if constexpr(std::is_same_v) + { + float invscale_p = ck_tile::type_convert(ck_tile::numeric::max()); + float scale_p = 1.f / invscale_p; + + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{invscale_p}, // p_compute_element_func + scales{scale_p}, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } } }();