From e74ae6664a09e284ab96885098f1f657ab9bfcac Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 17 Dec 2024 01:24:22 +0000 Subject: [PATCH] qsksvs pipeline changes to mirror qrksvs [ROCm/composable_kernel commit: f7942b993cd70a29e9d392bc5df79b1d3c359ff5] --- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 2 + ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 4 ++ .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 45 ++++++++++++++++++- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 2 +- 4 files changed, 50 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7ac86e6d12..4b16b1fc81 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -95,6 +95,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline { constexpr std::array occupancy{2, 2, 2, 2, 2, 1}; return occupancy[detail::log2::value - 2]; + } else if constexpr(kHeadDimV <= 512) { + return 1; } } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 04aa85644d..01a7bd36f4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -96,6 +96,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { return 1; } + else if constexpr(kQKHeaddim <= 512) + { + return 1; + } } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index b98247df9c..a52ba83dd8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -12,7 +12,7 @@ namespace ck_tile { /// NOTICE: we no-longer use this pipeline. // This pipeline is qkv all located in LDS template -struct [[deprecated]] BlockFmhaPipelineQSKSVS +struct BlockFmhaPipelineQSKSVS { using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) @@ -81,6 +99,9 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS static constexpr const char* name = "qs"; + // using DropoutType = std::conditional_t; + using DropoutType = int32_t; // unused + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -95,6 +116,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -106,6 +128,23 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS typename OAccElementFunction, typename PositionEncoding> CK_TILE_HOST_DEVICE auto + // operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + // const QElementFunction& q_element_func, + // const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + // const KElementFunction& k_element_func, + // const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + // const VElementFunction& v_element_func, + // const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + // const BiasElementFunction& bias_element_func, + // LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + // const LSEElementFunction& lse_element_func, + // const SAccElementFunction& s_acc_element_func, + // const PComputeElementFunction& p_compute_element_func, + // const OAccElementFunction& o_acc_element_func, + // FmhaMask mask, + // PositionEncoding position_encoding, + // float scale_s, + // void* smem_ptr) const operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -114,6 +153,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -122,7 +162,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 1c9df46449..4d3c7c09d2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy /// NOTICE: we no-longer use this policy. template <> -struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy +struct BlockFmhaPipelineQXCustomPolicy { static constexpr bool QLoadOnce = false;