From d3d53433aacc7e6eeb5645b45f080169aa715710 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed, 18 Dec 2024 00:49:27 +0000 Subject: [PATCH] update qsksvs pipeline [ROCm/composable_kernel commit: bfc997a7e69de42ac471f56c001725c9c438ac20] --- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 51 +++++++++++++++++-- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index a52ba83dd8..82e353be6a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" namespace ck_tile { @@ -99,8 +100,7 @@ struct BlockFmhaPipelineQSKSVS static constexpr const char* name = "qs"; - // using DropoutType = std::conditional_t; - using DropoutType = int32_t; // unused + using DropoutType = std::conditional_t; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { @@ -267,7 +267,8 @@ struct BlockFmhaPipelineQSKSVS bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + Policy::template MakeBiasDramTileDistribution()); + // Policy::template MakeBiasDramTileDistribution()); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -620,10 +621,46 @@ struct BlockFmhaPipelineQSKSVS return o_acc; } + // template + // CK_TILE_HOST_DEVICE auto + // operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + // const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + // const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + // const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + // LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + // FmhaMask mask, + // PositionEncoding position_encoding, + // float scale_s, + // void* smem_ptr) const + // { + // return operator()(q_dram_block_window_tmp, + // identity{}, + // k_dram_block_window_tmp, + // identity{}, + // v_dram_block_window_tmp, + // identity{}, + // bias_dram_block_window_tmp, + // identity{}, + // lse_dram_block_window_tmp, + // identity{}, + // identity{}, + // identity{}, + // identity{}, + // mask, + // position_encoding, + // scale_s, + // smem_ptr); + // } template CK_TILE_HOST_DEVICE auto @@ -631,11 +668,13 @@ struct BlockFmhaPipelineQSKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -645,6 +684,7 @@ struct BlockFmhaPipelineQSKSVS identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -653,7 +693,8 @@ struct BlockFmhaPipelineQSKSVS mask, position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } };