From 81bd4741529fc97de04431c47fe784d7b29a2d2e Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Sat, 12 Oct 2024 13:58:45 +0000 Subject: [PATCH] Remove the using of MakeKRegBlockDescriptor and MakeVRegBlockDescriptor --- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 7 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 7 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 64 ------------------- 3 files changed, 4 insertions(+), 74 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index e2719a2137..a87dac3757 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -189,7 +189,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR Policy::template MakeKRegSliceBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( - Policy::template MakeKRegBlockDescriptor()); + Policy::template MakeKRegSliceBlockDescriptor()); //------------------------------------------------------------------ // V, HBM ->LDS ->Reg @@ -214,9 +214,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR v_lds_write_window.get_window_origin(), Policy::template MakeVRegSliceBlockDescriptor()); - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); - //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg auto shuffled_k_block_tile = make_static_distributed_tensor( @@ -259,7 +256,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); block_sync_lds(); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 68fc5a12bc..8286b7feb7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -189,7 +189,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP Policy::template MakeKRegSliceBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( - Policy::template MakeKRegBlockDescriptor()); + Policy::template MakeKRegSliceBlockDescriptor()); //------------------------------------------------------------------ // V, HBM ->LDS ->Reg @@ -214,9 +214,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_write_window.get_window_origin(), Policy::template MakeVRegSliceBlockDescriptor()); - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); - //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg auto shuffled_k_block_tile = make_static_distributed_tensor( @@ -259,7 +256,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS auto q_dram_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index ca39f457e8..b62dc2def4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -849,38 +849,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy return k_block_dstr; } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { @@ -924,38 +892,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy return v_block_dstr; } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - - return v_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor() {