From 272bd532f6daef370afac4ad0124eb7ec1a25880 Mon Sep 17 00:00:00 2001 From: AMD-dteng Date: Sun, 27 Apr 2025 13:50:44 +0800 Subject: [PATCH] move make tile window out of hotloop --- .../core/arch/amd_buffer_addressing.hpp | 8 +-- .../arch/amd_buffer_addressing_builtins.hpp | 8 +-- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 51 ++++++++++--------- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 51 ++++++++++++++----- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 10 ++-- 5 files changed, 81 insertions(+), 47 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 5d6d6ce348..124af4586b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -38,10 +38,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 0b9956cd01..1bd1edd7b3 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -29,10 +29,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 0833a9e5e6..54a060b45c 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -68,48 +68,51 @@ struct BlockFlatmmASmemBSmemCRegV1 // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindow& a_block_window, + ABlockWindow& a_warp_windows, BFlatBlockTensor& b_warp_tensor) const { - constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; + // constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; - static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!"); + // static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!"); + + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); + // constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + // constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + // constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - const index_t iMWarp = get_warp_id() / NWarp; + // const index_t iMWarp = get_warp_id() / NWarp; // construct A-warp-window - auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + // auto a_warp_window_tmp = make_tile_window( + // a_block_window.get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + // make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + // statically_indexed_array< + // statically_indexed_array, + // MIterPerWarp> + // a_warp_windows; + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); + // move_tile_window(a_warp_windows(mIter)(kIter), + // {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + // }); + // }); using CWarpDstr = typename WG::CWarpDstr; using CWarpTensor = typename WG::CWarpTensor; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index c0f0be1b65..3b81b86204 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -92,16 +92,17 @@ struct FlatmmPipelineAGmemBGmemCRegV1 constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; // constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num; - static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); + static_for<0, A_LDS_Read_Inst_Num-A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read @@ -110,7 +111,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA }); } @@ -134,15 +135,21 @@ struct FlatmmPipelineAGmemBGmemCRegV1 using WG = remove_cvref_t())>; - // constexpr index_t MWarp = config.template at<1>(); + constexpr index_t MWarp = config.template at<1>(); constexpr index_t NWarp = config.template at<2>(); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); @@ -166,6 +173,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1 auto a_lds_gemm_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_warp_window_tmp = make_tile_window( + a_lds_gemm_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + // Block GEMM auto block_flatmm = BlockFlatmm(); @@ -245,7 +271,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 a_block_tile = load_tile(a_copy_dram_window); // GEMM i - block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor); + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); block_sync_lds(); @@ -278,7 +304,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 a_block_tile = load_tile(a_copy_dram_window); // GEMM i - block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2); + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); block_sync_lds(); @@ -315,7 +341,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 a_block_tile = load_tile(a_copy_dram_window); // GEMM i - block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor); + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); block_sync_lds(); @@ -340,10 +366,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // move to next flat K // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + HotLoopScheduler(); block_sync_lds(); // GEMM num_loop - 1 - block_flatmm(c_block_tile, a_lds_gemm_window, b_warp_tensor_2); + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); } return c_block_tile; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index a4b244e06e..be6bff9d8b 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -19,11 +19,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using namespace ck_tile; - + +#if 0 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = GetSmemPackA(); -#if 1 + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), @@ -39,6 +40,9 @@ struct UniversalFlatmmPipelineAgBgCrPolicy #endif /*xor*/ #if 0 + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); using ADataType = remove_cvref_t; constexpr auto DataTypeSize = sizeof(ADataType); @@ -79,7 +83,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); #endif -#if 0 /*reduce transform layers,compare with old ck*/ +#if 1 /*reduce transform layers,compare with old ck*/ constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPack = GetSmemPackA();