From 63cc962000ac81cf3121fd3030b59049c4cde4d3 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:01:50 +0000 Subject: [PATCH] clang-format and remove dead code [ROCm/composable_kernel commit: edb78a4729278289a7d1bda94123aadec9821d1e] --- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 108 ++++++------------ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 4 +- 2 files changed, 37 insertions(+), 75 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 82e353be6a..a59a59f85c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -128,42 +128,39 @@ struct BlockFmhaPipelineQSKSVS typename OAccElementFunction, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - // operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - // const QElementFunction& q_element_func, - // const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - // const KElementFunction& k_element_func, - // const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - // const VElementFunction& v_element_func, - // const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - // const BiasElementFunction& bias_element_func, - // LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile - // const LSEElementFunction& lse_element_func, - // const SAccElementFunction& s_acc_element_func, - // const PComputeElementFunction& p_compute_element_func, - // const OAccElementFunction& o_acc_element_func, - // FmhaMask mask, - // PositionEncoding position_encoding, - // float scale_s, - // void* smem_ptr) const - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile - const LSEElementFunction& lse_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, + operator()(const QDramBlockWindowTmp & q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction & + q_element_func, + const KDramBlockWindowTmp & + k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction & + k_element_func, + const VDramBlockWindowTmp & + v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction & + v_element_func, + const BiasDramBlockWindowTmp & + bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction & + bias_element_func, + RandValDramBlockWindowTmp & + randval_dram_block_window_tmp, + LSEDramBlockWindowTmp & + lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction & + lse_element_func, + const SAccElementFunction & + s_acc_element_func, + const PComputeElementFunction & + p_compute_element_func, + const OAccElementFunction & + o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - DropoutType& dropout) const + DropoutType & + dropout) const { static_assert( std::is_same_v> && @@ -263,12 +260,12 @@ struct BlockFmhaPipelineQSKSVS {seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = make_tile_window( - bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); - // Policy::template MakeBiasDramTileDistribution()); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + // Policy::template MakeBiasDramTileDistribution()); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -621,41 +618,6 @@ struct BlockFmhaPipelineQSKSVS return o_acc; } - // template - // CK_TILE_HOST_DEVICE auto - // operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - // const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - // const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - // const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - // LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - // FmhaMask mask, - // PositionEncoding position_encoding, - // float scale_s, - // void* smem_ptr) const - // { - // return operator()(q_dram_block_window_tmp, - // identity{}, - // k_dram_block_window_tmp, - // identity{}, - // v_dram_block_window_tmp, - // identity{}, - // bias_dram_block_window_tmp, - // identity{}, - // lse_dram_block_window_tmp, - // identity{}, - // identity{}, - // identity{}, - // identity{}, - // mask, - // position_encoding, - // scale_s, - // smem_ptr); - // } template CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsStoreBlockDescriptor(number = number<0>{}) + MakeKLdsStoreBlockDescriptor(number = number<0>{}) { // K is always k-major, we use async-copy to load into LDS constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; @@ -526,7 +526,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsLoadBlockDescriptor(number = number<0>{}) + MakeKLdsLoadBlockDescriptor(number = number<0>{}) { // K is always k-major, we use async-copy to load into LDS constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;