From 260ace4b78c921c865c1cc5ccaedef66a5d60bfa Mon Sep 17 00:00:00 2001 From: danyao12 Date: Mon, 22 Jul 2024 11:35:34 +0800 Subject: [PATCH] code cleanup --- .../ck_tile/ops/fmha/block/block_dropout.hpp | 60 ++++++++----------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 73d89d4907..363cb0fa4c 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -58,43 +58,33 @@ struct BlockDropout MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, index_t seqlen_qk_start) { - if constexpr(IsDropout) - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - constexpr index_t kMPerStep = MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; - const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); - auto randval_dram_window = [&]() { - if constexpr(IsFwd) - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N - } - else - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N - } - }(); + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); - return randval_dram_window; - } - else - { - (void)randval_dram_block_window_tmp; - (void)seqlen_qk_start; - - return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); - } + return randval_dram_window; } template