From 918d5b21bc7b7200234f81940e940443db004292 Mon Sep 17 00:00:00 2001 From: YC Lin Date: Wed, 16 Apr 2025 03:06:54 +0000 Subject: [PATCH 01/21] [GEMM] Fix num_loop issues --- .../02_gemm/block_gemm_asmem_bsmem_creg.hpp | 94 ++++++------- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 59 ++++---- ...peline_agmem_bgmem_creg_default_policy.hpp | 8 -- .../ck_tile/99_toy_example/02_gemm/gemm.hpp | 132 ------------------ 4 files changed, 68 insertions(+), 225 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index ddea73644a..7f6ec4307f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -20,7 +20,9 @@ struct BlockGemmASmemBSmemCReg using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using WarpGemm = remove_cvref_t().template at<0>())>; + using WarpGemm = remove_cvref_t().template get<0>())>; + static constexpr index_t MWarp = Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = Policy::template GetWarpGemmMWarpNWarp().template get<2>(); using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; @@ -45,11 +47,7 @@ struct BlockGemmASmemBSmemCReg // A block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM); - constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; @@ -69,11 +67,7 @@ struct BlockGemmASmemBSmemCReg // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN); - constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; @@ -128,16 +122,13 @@ struct BlockGemmASmemBSmemCReg static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; + // using WarpGemm = remove_cvref_t())>; - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; @@ -149,9 +140,9 @@ struct BlockGemmASmemBSmemCReg // construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); statically_indexed_array, MIterPerWarp> a_warp_windows; @@ -167,9 +158,9 @@ struct BlockGemmASmemBSmemCReg // construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); statically_indexed_array, NIterPerWarp> b_warp_windows; @@ -185,27 +176,25 @@ struct BlockGemmASmemBSmemCReg // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { -#if defined(ENABLE_INSTRUCTION_SCH) -#pragma message ("local data share prefetch") // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; +#if defined(ENABLE_INSTRUCTION_SCH) +#pragma message ("local data share prefetch") a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { -#if defined(ENABLE_INSTRUCTION_SCH) // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; +#if defined(ENABLE_INSTRUCTION_SCH) b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -215,7 +204,7 @@ struct BlockGemmASmemBSmemCReg merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -242,16 +231,13 @@ struct BlockGemmASmemBSmemCReg static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t{}))>; + // using WarpGemm = remove_cvref_t{}))>; - constexpr index_t MWarp = config.template get(number<1>{}); - constexpr index_t NWarp = config.template get(number<2>{}); - - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; @@ -263,9 +249,9 @@ struct BlockGemmASmemBSmemCReg // construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); statically_indexed_array, MIterPerWarp> a_warp_windows; @@ -281,9 +267,9 @@ struct BlockGemmASmemBSmemCReg // construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); statically_indexed_array, NIterPerWarp> b_warp_windows; @@ -296,7 +282,7 @@ struct BlockGemmASmemBSmemCReg }); }); - static_assert(std::is_same_v, "wrong!"); + static_assert(std::is_same_v, "wrong!"); // Construct C-Block-Tensor constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -308,7 +294,7 @@ struct BlockGemmASmemBSmemCReg sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); @@ -317,26 +303,24 @@ struct BlockGemmASmemBSmemCReg // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { -#if defined(ENABLE_INSTRUCTION_SCH) // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; +#if defined(ENABLE_INSTRUCTION_SCH) a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { -#if defined(ENABLE_INSTRUCTION_SCH) // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; +#if defined(ENABLE_INSTRUCTION_SCH) b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -345,7 +329,7 @@ struct BlockGemmASmemBSmemCReg if constexpr(KIterPerWarp == 0) { // c = a * b - c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); } else { @@ -354,7 +338,7 @@ struct BlockGemmASmemBSmemCReg merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } // write C warp tensor into C block tensor diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index c488f52185..d96f27b66b 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -47,14 +47,6 @@ struct BlockGemmPipelineAGmemBGmemCReg static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; - static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; @@ -74,8 +66,8 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; constexpr index_t WaveSize = 64; - constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); - constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + constexpr index_t WaveNumM = BlockGemm::MWarp; + constexpr index_t WaveNumN = BlockGemm::NWarp; constexpr index_t AB_LDS_RW_Width = GetSmemPack(); @@ -321,35 +313,39 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); - // Prefetch - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - load_tile(b_block_tile, b_copy_dram_window); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - // LDS write 0 - store_tile(a_copy_lds_window, a_block_tile); - store_tile(b_copy_lds_window, b_block_tile); - + // Prefetch // Global read 0 load_tile(a_block_tile, a_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); load_tile(b_block_tile, b_copy_dram_window); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - block_sync_lds(); + if (num_loop > 1) + { + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - // Prefetch from LDS to warp register in block gemm - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS write 0 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 0 + load_tile(a_block_tile, a_copy_dram_window); + load_tile(b_block_tile, b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } __builtin_amdgcn_sched_barrier(0); // Main body - if constexpr(HasHotLoop) + if (num_loop > 2) { index_t i = 0; do @@ -362,8 +358,8 @@ struct BlockGemmPipelineAGmemBGmemCReg // Global read 0 load_tile(a_block_tile, a_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); load_tile(b_block_tile, b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); @@ -382,8 +378,11 @@ struct BlockGemmPipelineAGmemBGmemCReg } // Tail - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); + if (num_loop > 1) + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 71d97c4ce4..8149fb4132 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -314,12 +314,8 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - return GetGlobalVectorLoadSize(); } @@ -327,12 +323,8 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - return GetGlobalVectorLoadSize(); } diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 9b196c96b8..8cc3013d5f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -29,28 +29,6 @@ struct GridGemmProblem using CElementFunction = CElementFunction_; }; -#if defined(ENABLE_INSTRUCTION_SCH) -template -struct TileGemmShape -{ - using BlockTile = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - - static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); - - static constexpr index_t kM = BlockTile::at(number<0>{}); - static constexpr index_t kN = BlockTile::at(number<1>{}); - static constexpr index_t kK = BlockTile::at(number<2>{}); - - static constexpr bool PermuteA = PermuteA_; - static constexpr bool PermuteB = PermuteB_; -}; -#else template struct TileGemmShape { @@ -58,71 +36,7 @@ struct TileGemmShape static constexpr index_t kN = kNPerTile; static constexpr index_t kK = kKPerTile; }; -#endif -#if defined(ENABLE_INSTRUCTION_SCH) -template -struct TileGemmTraits -{ - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kPadK = kPadK_; - - static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; - - static constexpr bool TransposeC = TransposeC_; -}; - -template -struct BlockGemmPipelineProblem -{ - using Traits = remove_cvref_t; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - - using BlockGemmShape = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadM = Traits::kPadM; - static constexpr bool kPadN = Traits::kPadN; - static constexpr bool kPadK = Traits::kPadK; - - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; - - static constexpr bool TransposeC = Traits::TransposeC; -}; -#else template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() { -#if defined(ENABLE_INSTRUCTION_SCH) - // Block GEMM pipeline w/ instruction scheduling - using GemmShape = TileGemmShape, - sequence, - sequence, - PermuteA, - PermuteB>; - - using GemmTraits = TileGemmTraits; - - using BlockGemmPipelineProblem_ = - BlockGemmPipelineProblem; -#else using BlockGemmPipelineProblem_ = BlockGemmPipelineProblem>; - -#endif return BlockGemmPipelineAGmemBGmemCReg{}; } }; From 77a96c7a823e05e4748985180abfc8f2c6e414bc Mon Sep 17 00:00:00 2001 From: MHYang Date: Fri, 18 Apr 2025 10:15:17 +0000 Subject: [PATCH 02/21] Fix register spilling and K0 tile size issues --- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 3 +-- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 6 ++++++ .../flash_attention_fwd_impl.hpp | 21 ++++++++----------- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 3 +-- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 6 ++++++ .../flash_attention_fwd_impl.hpp | 21 ++++++++----------- 6 files changed, 32 insertions(+), 28 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 92b98a2f11..4a8c2beeb7 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index cdce1b1f31..2cafb715a2 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4229db5250..234ff8821a 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); @@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl {iN1, 0}, MakeVDramTileDistribution()); - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS @@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 92b98a2f11..4a8c2beeb7 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index cdce1b1f31..2cafb715a2 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4229db5250..234ff8821a 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); @@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl {iN1, 0}, MakeVDramTileDistribution()); - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS @@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); From 8a6cc0e94b50bb5070ddeeb659c0f8b094ce29fd Mon Sep 17 00:00:00 2001 From: YC Lin Date: Mon, 21 Apr 2025 16:44:23 +0000 Subject: [PATCH 03/21] [GEMM] Fix bWarpTile issue and remove redundant pipeline in BlockGemmPipeline --- .../02_gemm/block_gemm_asmem_bsmem_creg.hpp | 34 ++-- ...k_gemm_asmem_bsmem_creg_default_policy.hpp | 40 ++++- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 156 +++++------------- 3 files changed, 88 insertions(+), 142 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index 7f6ec4307f..4419d31e28 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -43,7 +43,7 @@ struct BlockGemmASmemBSmemCReg static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { @@ -92,16 +92,16 @@ struct BlockGemmASmemBSmemCReg using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + ALdsTile aWarpTile; + BLdsTile bWarpTile; // Prefetch from LDS to warp register template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - load_tile(a_warp_tile_, a_block_window); - load_tile(b_warp_tile_, b_block_window); + aWarpTile = load_tile(a_block_window); + bWarpTile = load_tile(b_block_window); } #endif @@ -178,23 +178,23 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) #pragma message ("local data share prefetch") - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -305,22 +305,22 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; -#if defined(ENABLE_INSTRUCTION_SCH) - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index a6bc4e7563..ec71f9c76a 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -17,19 +17,31 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { +#if defined(ADJUST_BLOCK_TILE_SHAPE) + constexpr index_t kMWarp = 2; + constexpr index_t kNWarp = 2; +#else + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; +#endif + #if defined(NAIVE_IMPLEMENTATION) #pragma message ("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_32x32x_8x2) #pragma message ("mfma m32 n32 k16") @@ -37,13 +49,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_16x16x16) #pragma message ("mfma m16 n16 k16") @@ -51,13 +67,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, + kMWarp, + kNWarp); } #elif defined(USING_MFMA_16x16x_16x2) #pragma message ("mfma m16 n16 k32") @@ -65,13 +85,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, + kMWarp, + kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, + kMWarp, + kNWarp); } #endif else diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index d96f27b66b..26d5618330 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -42,15 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg } #if defined(ENABLE_INSTRUCTION_SCH) - static constexpr index_t APackedSize = + static constexpr index_t kPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -72,31 +65,31 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr index_t AB_LDS_RW_Width = GetSmemPack(); constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB()); constexpr index_t A_LDS_Write_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Write_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / + constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes constexpr auto num_ds_read_inst_a = - AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num : + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst_b = - AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num : + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; @@ -109,9 +102,9 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; constexpr auto ds_read_a_issue_cycle = - AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = @@ -266,15 +259,15 @@ struct BlockGemmPipelineAGmemBGmemCReg {0, 0}, b_copy_dram_window.get_tile_distribution()); -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}, + a_lds_block, make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); #else // A LDS tile for block GEMM @@ -303,23 +296,23 @@ struct BlockGemmPipelineAGmemBGmemCReg ABlockTile a_block_tile; BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); // ------------------------------------------------------------------------------------- // Gemm pipeline start -#if defined(ENABLE_INSTRUCTION_SCH) - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); +#if defined(ENABLE_PREFETCH) // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // Prefetch // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); if (num_loop > 1) { @@ -331,8 +324,8 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); @@ -357,8 +350,8 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); @@ -369,12 +362,14 @@ struct BlockGemmPipelineAGmemBGmemCReg // Prefetch from LDS to warp register in block gemm block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); +#if defined(ENABLE_INSTRUCTION_SCH) HotLoopScheduler(); +#endif __builtin_amdgcn_sched_barrier(0); - i += 1; - } while(i < (num_loop - 2)); + iCounter += 1; + } while(iCounter < (num_loop - 2)); } // Tail @@ -388,84 +383,12 @@ struct BlockGemmPipelineAGmemBGmemCReg block_sync_lds(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - -#elif defined(ENABLE_PREFETCH) - // Prefetch - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); - - { - // Move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // Initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, a_block_tile); - // Global read 1 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write 0 - store_tile(b_copy_lds_window, b_block_tile); - // Global read 1 - load_tile(b_block_tile, b_copy_dram_window); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // Move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, a_block_tile); - // Global read i + 2 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write i + 1 - store_tile(b_copy_lds_window, b_block_tile); - // Global read i + 2 - load_tile(b_block_tile, b_copy_dram_window); - - iCounter--; - - } while(iCounter > 0); - - // Tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, a_block_tile); - - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } #else // non-prefetch - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); @@ -477,11 +400,10 @@ struct BlockGemmPipelineAGmemBGmemCReg while (iCounter > 0) { - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); From 252b72ec30c7939a62947748444caab018d745ba Mon Sep 17 00:00:00 2001 From: MHYang Date: Mon, 21 Apr 2025 16:44:56 +0000 Subject: [PATCH 04/21] Fix bank conflict --- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 45 ++++++++++++++---- .../flash_attention_fwd_impl.hpp | 46 ++++++++++++++----- 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index ad6d6d3996..3ab79e23b1 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -36,20 +36,47 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() { constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = - transform_tensor_descriptor(b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 234ff8821a..94016025c4 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -68,23 +68,45 @@ struct FlashAttentionFwdImpl { constexpr index_t kNPerBlock = kN1PerBlock; constexpr index_t kKPerBlock = kK1PerBlock; - constexpr index_t kPad = 1; - // 2% faster than use kK1 = 8 - constexpr index_t kK1 = 4; + constexpr index_t kKPack = 4; + + constexpr auto dataTypeSize = sizeof(VDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number{}, number<1>{}), - number{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } @@ -140,7 +162,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); From 30e4b12ef49ebc6053fd91f59588069e71849e0d Mon Sep 17 00:00:00 2001 From: MHYang Date: Mon, 21 Apr 2025 16:50:56 +0000 Subject: [PATCH 05/21] Merge fix for bank conflict into codegen FA --- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 45 ++++++++++++++---- .../flash_attention_fwd_impl.hpp | 46 ++++++++++++++----- 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index ad6d6d3996..3ab79e23b1 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -36,20 +36,47 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() { constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = - transform_tensor_descriptor(b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 234ff8821a..94016025c4 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -68,23 +68,45 @@ struct FlashAttentionFwdImpl { constexpr index_t kNPerBlock = kN1PerBlock; constexpr index_t kKPerBlock = kK1PerBlock; - constexpr index_t kPad = 1; - // 2% faster than use kK1 = 8 - constexpr index_t kK1 = 4; + constexpr index_t kKPack = 4; + + constexpr auto dataTypeSize = sizeof(VDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number{}, number<1>{}), - number{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } @@ -140,7 +162,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); From de097a54d6fb10e2e952d6157ffa6e756369551b Mon Sep 17 00:00:00 2001 From: BoboFang Date: Tue, 22 Apr 2025 13:55:21 +0000 Subject: [PATCH 06/21] Add cache-aware in flash attention --- .../flash_attention_fwd.hpp | 74 +++++++++++++++++-- .../flash_attention_fwd.hpp | 22 ++++-- 2 files changed, 84 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index caeeece8e9..085366c0b3 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -15,6 +15,58 @@ namespace ck_tile { +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, + index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01)+ lu ) ) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if (block_1d_id >= total_update_size) { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if (rlen > 0) { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if (rlen > 0 and rm < M0) { + n = rn + update_N0; + m = rm; + } else { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) // O[M0, N1] = P[M0, N0] * V[N1, N0] @@ -53,26 +105,36 @@ struct FlashAttentionFwd const index_t BatchStrideV, const index_t BatchStrideO) const { - // divide problem - const index_t num_tile_m0 = M0 / kM0PerBlock; - const index_t num_tile_n1 = N1 / kN1PerBlock; - const index_t id_block = get_block_id(); + const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); + const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); + +#if defined(GEMM_OPT) + const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); + + const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; + const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); + + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); + +#else const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; index_t modulus = dividend - quotient * divisor; return make_tuple(quotient, modulus); }; - const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); +#endif + const auto kernel_impl = FlashAttentionFwdImpl{}) % num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); + +#else const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; index_t modulus = dividend - quotient * divisor; return make_tuple(quotient, modulus); }; - const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); +#endif + const auto kernel_impl = FlashAttentionFwdImpl Date: Tue, 22 Apr 2025 14:43:38 +0000 Subject: [PATCH 07/21] Initialize instruction schedule --- .../block_gemm_areg_bsmem_creg_v1.hpp | 94 +++++++++++++++++++ ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 9 ++ .../flash_attention_fwd_impl.hpp | 4 + 3 files changed, 107 insertions(+) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 740c540d6c..d502655210 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -26,6 +26,100 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / + (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (kBlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4a8c2beeb7..d54b460b71 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -134,6 +134,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -158,6 +160,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -276,6 +281,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -293,6 +299,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 94016025c4..0d734aca3d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -310,6 +310,7 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); @@ -321,6 +322,9 @@ struct FlashAttentionFwdImpl block_sync_lds(); store_tile(v_lds_window, v); move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail From 95a8ac00c6668d03f0c42fd3d6951b3703c84fc6 Mon Sep 17 00:00:00 2001 From: BoboFang Date: Tue, 22 Apr 2025 14:49:37 +0000 Subject: [PATCH 08/21] Run clang-format in toy_example --- .../99_toy_example/01_add/CMakeLists.txt | 33 +-- example/ck_tile/99_toy_example/01_add/add.cpp | 8 +- example/ck_tile/99_toy_example/01_add/add.hpp | 32 ++- .../99_toy_example/01_add/reference_add.hpp | 11 +- .../99_toy_example/02_gemm/CMakeLists.txt | 41 ++-- .../02_gemm/block_gemm_asmem_bsmem_creg.hpp | 61 +++-- ...k_gemm_asmem_bsmem_creg_default_policy.hpp | 48 ++-- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 110 +++++---- ...peline_agmem_bgmem_creg_default_policy.hpp | 80 +++--- .../ck_tile/99_toy_example/02_gemm/config.h | 55 ++--- .../ck_tile/99_toy_example/02_gemm/gemm.cpp | 41 ++-- .../ck_tile/99_toy_example/02_gemm/gemm.hpp | 66 ++--- .../99_toy_example/02_gemm/grid_gemm.hpp | 6 +- .../99_toy_example/02_gemm/reference_gemm.hpp | 10 +- .../03_flash_attention_fwd/CMakeLists.txt | 31 +-- .../block_gemm_areg_bsmem_creg_v1.hpp | 100 ++++---- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 11 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 14 +- .../block_gemm_pipeline_problem.hpp | 1 - ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 25 +- .../flash_attention_fwd.cpp | 123 +++++----- .../flash_attention_fwd.hpp | 49 ++-- .../flash_attention_fwd_impl.hpp | 69 +++--- .../reference_batched_gemm.hpp | 7 +- .../reference_batched_softmax.hpp | 7 +- .../CMakeLists.txt | 85 +++---- .../block_gemm_areg_bsmem_creg_v1.hpp | 88 ++++--- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 11 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 14 +- .../block_gemm_pipeline_problem.hpp | 1 - ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 25 +- .../flash_attention_fwd.cpp | 82 +++---- .../flash_attention_fwd.hpp | 91 ++++--- .../flash_attention_fwd_impl.hpp | 69 +++--- .../generate.py | 230 +++++++++--------- .../reference_batched_gemm.hpp | 7 +- .../reference_batched_softmax.hpp | 7 +- 37 files changed, 890 insertions(+), 859 deletions(-) mode change 100755 => 100644 example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt mode change 100755 => 100644 example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt diff --git a/example/ck_tile/99_toy_example/01_add/CMakeLists.txt b/example/ck_tile/99_toy_example/01_add/CMakeLists.txt index f241173fe7..ad51b75885 100644 --- a/example/ck_tile/99_toy_example/01_add/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/01_add/CMakeLists.txt @@ -1,22 +1,23 @@ set(EXAMPLE_REDUCE "add") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message("adding example ${EXAMPLE_REDUCE}") +#not using add_example_executable() to add this target, since we don't want this to have +#to be included in "make all/install/check" + message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add.cpp) -target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add.cpp) target_include_directories(${ + EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -# generate assembly -# list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +#generate assembly +#list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - v-- save - temps - Wno - gnu - line - marker) -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +#NOTE : we turn off undefined - func - \ + template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - Wno - + float - equal) -target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) -# TODO: we have to turn off this global prop, otherwise the progress bar generated -# by cmake will print too many files, execvp: /bin/sh: Argument list too long -# however, this property may affect global -# TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +#TODO : we have to turn off this global prop, otherwise the progress bar generated +#by cmake will print too many files, execvp : / bin / sh : Argument list too long +#however, this property may affect global +#TODO : consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/01_add/add.cpp b/example/ck_tile/99_toy_example/01_add/add.cpp index 4e50e7d526..cd4ce141dc 100644 --- a/example/ck_tile/99_toy_example/01_add/add.cpp +++ b/example/ck_tile/99_toy_example/01_add/add.cpp @@ -57,9 +57,8 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl; std::cout << "grid size " << kGridSize << std::endl; - using Shape = ck_tile::AddShape; - using Porblem = - ck_tile::AddProblem; + using Shape = ck_tile::AddShape; + using Porblem = ck_tile::AddProblem; using Kernel = ck_tile::Add; @@ -85,8 +84,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { - ck_tile::reference_add( - x_host_a, x_host_b, y_host_ref); + ck_tile::reference_add(x_host_a, x_host_b, y_host_ref); y_buf.FromDevice(y_host_dev.mData.data()); pass = ck_tile::check_err(y_host_dev, y_host_ref); diff --git a/example/ck_tile/99_toy_example/01_add/add.hpp b/example/ck_tile/99_toy_example/01_add/add.hpp index f77301be06..87eb86e0c4 100644 --- a/example/ck_tile/99_toy_example/01_add/add.hpp +++ b/example/ck_tile/99_toy_example/01_add/add.hpp @@ -36,10 +36,7 @@ struct AddShape warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); }; -template +template struct AddProblem { using XDataType = remove_cvref_t; @@ -76,7 +73,8 @@ struct Add using ComputeDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; - CK_TILE_DEVICE void operator()(const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M, index_t N) const + CK_TILE_DEVICE void operator()( + const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t M, index_t N) const { using S = typename Problem::BlockShape; @@ -98,14 +96,14 @@ struct Add const auto iM = get_block_id() * S::Block_M; auto x_window_a = make_tile_window(x_m_n_a, - make_tuple(number{}, number{}), - {iM, 0}, - Policy::template MakeXBlockTileDistribution()); + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); auto x_window_b = make_tile_window(x_m_n_b, - make_tuple(number{}, number{}), - {iM, 0}, - Policy::template MakeXBlockTileDistribution()); + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); auto y_window = make_tile_window(y_m_n, make_tuple(number{}, number{}), @@ -117,17 +115,17 @@ struct Add for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto xa = load_tile(x_window_a); - const auto xb = load_tile(x_window_b); + const auto xa = load_tile(x_window_a); + const auto xb = load_tile(x_window_b); auto y_compute = load_tile(y_window); constexpr auto spans = decltype(xa)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - const auto x = ck_tile::type_convert(xa[i_j_idx]); - const auto y = ck_tile::type_convert(xb[i_j_idx]); - y_compute(i_j_idx) = x + y; + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + const auto x = ck_tile::type_convert(xa[i_j_idx]); + const auto y = ck_tile::type_convert(xb[i_j_idx]); + y_compute(i_j_idx) = x + y; }); }); diff --git a/example/ck_tile/99_toy_example/01_add/reference_add.hpp b/example/ck_tile/99_toy_example/01_add/reference_add.hpp index 26a72286da..a1e09f1c84 100644 --- a/example/ck_tile/99_toy_example/01_add/reference_add.hpp +++ b/example/ck_tile/99_toy_example/01_add/reference_add.hpp @@ -10,19 +10,22 @@ namespace ck_tile { template -CK_TILE_HOST void -reference_add(const HostTensor& xa_m_n, const HostTensor& xb_m_n, HostTensor& y_m_n) +CK_TILE_HOST void reference_add(const HostTensor& xa_m_n, + const HostTensor& xb_m_n, + HostTensor& y_m_n) { auto f = [&](auto m) { const int N = xa_m_n.mDesc.get_lengths()[1]; for(int n = 0; n < N; ++n) { - y_m_n(m, n) = ck_tile::type_convert(xa_m_n(m, n)) + ck_tile::type_convert(xb_m_n(m, n)); + y_m_n(m, n) = ck_tile::type_convert(xa_m_n(m, n)) + + ck_tile::type_convert(xb_m_n(m, n)); } }; - make_ParallelTensorFunctor(f, y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, + y_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); } } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt index 02863398fe..afb32d49a9 100644 --- a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt @@ -1,27 +1,28 @@ set(EXAMPLE_REDUCE "basic_gemm") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message("adding example ${EXAMPLE_REDUCE}") +#not using add_example_executable() to add this target, since we don't want this to have +#to be included in "make all/install/check" + message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm.cpp) -target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm.cpp) target_include_directories(${ + EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -# generate assembly -# list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +#generate assembly +#list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - v-- save - temps - Wno - gnu - line - marker) -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +#NOTE : we turn off undefined - func - \ + template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - Wno - + float - equal) -if(DEFINED kernel) - message("Compiling with Kernel: ${kernel}") - target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1) -endif() + if(DEFINED kernel) message("Compiling with Kernel: ${kernel}") + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel} = 1) + endif() -target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ + EXAMPLE_REDUCE_COMPILE_OPTIONS}) -# TODO: we have to turn off this global prop, otherwise the progress bar generated -# by cmake will print too many files, execvp: /bin/sh: Argument list too long -# however, this property may affect global -# TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +#TODO : we have to turn off this global prop, otherwise the progress bar generated +#by cmake will print too many files, execvp : / bin / sh : Argument list too long +#however, this property may affect global +#TODO : consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index 4419d31e28..0fb5d2de7d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -20,9 +20,12 @@ struct BlockGemmASmemBSmemCReg using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using WarpGemm = remove_cvref_t().template get<0>())>; - static constexpr index_t MWarp = Policy::template GetWarpGemmMWarpNWarp().template get<1>(); - static constexpr index_t NWarp = Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; @@ -98,7 +101,7 @@ struct BlockGemmASmemBSmemCReg // Prefetch from LDS to warp register template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window) { aWarpTile = load_tile(a_block_window); bWarpTile = load_tile(b_block_window); @@ -112,15 +115,17 @@ struct BlockGemmASmemBSmemCReg const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && - std::is_same_v && - std::is_same_v, "wrong!"); + std::is_same_v && + std::is_same_v, + "wrong!"); constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -141,10 +146,14 @@ struct BlockGemmASmemBSmemCReg auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - statically_indexed_array, MIterPerWarp> a_warp_windows; + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -159,10 +168,14 @@ struct BlockGemmASmemBSmemCReg auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - statically_indexed_array, NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -179,7 +192,7 @@ struct BlockGemmASmemBSmemCReg // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; #if defined(ENABLE_PREFETCH) -#pragma message ("local data share prefetch") +#pragma message("local data share prefetch") a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); @@ -222,14 +235,16 @@ struct BlockGemmASmemBSmemCReg const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && - std::is_same_v, "wrong!"); + std::is_same_v, + "wrong!"); constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -250,10 +265,14 @@ struct BlockGemmASmemBSmemCReg auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - statically_indexed_array, MIterPerWarp> a_warp_windows; + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -268,10 +287,14 @@ struct BlockGemmASmemBSmemCReg auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - statically_indexed_array, NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -354,4 +377,4 @@ struct BlockGemmASmemBSmemCReg } }; -} // namespace ck +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index ec71f9c76a..2fdb63794f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -26,76 +26,68 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy #endif #if defined(NAIVE_IMPLEMENTATION) -#pragma message ("mfma m32 n32 k8") +#pragma message("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_32x32x_8x2) -#pragma message ("mfma m32 n32 k16") +#pragma message("mfma m32 n32 k16") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_16x16x16) -#pragma message ("mfma m16 n16 k16") +#pragma message("mfma m16 n16 k16") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_16x16x_16x2) -#pragma message ("mfma m16 n16 k32") +#pragma message("mfma m16 n16 k32") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, - kMWarp, - kNWarp); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); } #endif else diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 26d5618330..57a0614c7f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -80,17 +80,16 @@ struct BlockGemmPipelineAGmemBGmemCReg WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / - (kBlockSize / WaveSize) / - (MPerXDL * NPerXDL * KPerXDL); + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = - AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num : - A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = - AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num : - B_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 + ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; @@ -123,8 +122,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // ? sizeof(ComputeDataType) / // sizeof(ADataType) : sizeof(ComputeDataType) // / sizeof(BDataType); - constexpr auto num_mfma_stage1 = - num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; @@ -138,62 +136,65 @@ struct BlockGemmPipelineAGmemBGmemCReg ignore = i; static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, - num_mfma_per_issue - num_mfma_per_dswrite_a * num_dswrite_per_issue_a, 0); // MFMA + num_mfma_per_issue - num_mfma_per_dswrite_a * + num_dswrite_per_issue_a, + 0); // MFMA }); static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { ignore = i; static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, - num_mfma_per_issue - num_mfma_per_dswrite_b * num_dswrite_per_issue_b, 0); // MFMA + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA }); // stage 2 static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= - ds_read_a_mfma_rate) + ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read } else { - __builtin_amdgcn_sched_group_barrier( - 0x100, - num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, - 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= - ds_read_b_mfma_rate) + ds_read_b_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read } else { - __builtin_amdgcn_sched_group_barrier( - 0x100, - num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, - 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); } #endif - template + template CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, @@ -201,12 +202,12 @@ struct BlockGemmPipelineAGmemBGmemCReg { static_assert( std::is_same_v> && - std::is_same_v>, + std::is_same_v>, "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); // ----------------------------------------------------------------------------------------- @@ -220,8 +221,8 @@ struct BlockGemmPipelineAGmemBGmemCReg auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - 16) * 16; + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; // B tile in LDS BDataType* p_b_lds = static_cast( @@ -234,40 +235,44 @@ struct BlockGemmPipelineAGmemBGmemCReg // A DRAM tile window for load auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // A LDS tile window for store auto a_copy_lds_window = make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); // B DRAM tile window for load auto b_copy_dram_window = make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); // B LDS tile window for store auto b_copy_lds_window = make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); #if defined(ENABLE_PREFETCH) // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}, + a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); #else // A LDS tile for block GEMM @@ -285,14 +290,11 @@ struct BlockGemmPipelineAGmemBGmemCReg // Acc register tile auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); ABlockTile a_block_tile; BBlockTile b_block_tile; @@ -314,7 +316,7 @@ struct BlockGemmPipelineAGmemBGmemCReg a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); - if (num_loop > 1) + if(num_loop > 1) { move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); @@ -338,7 +340,7 @@ struct BlockGemmPipelineAGmemBGmemCReg __builtin_amdgcn_sched_barrier(0); // Main body - if (num_loop > 2) + if(num_loop > 2) { index_t i = 0; do @@ -373,7 +375,7 @@ struct BlockGemmPipelineAGmemBGmemCReg } // Tail - if (num_loop > 1) + if(num_loop > 1) { block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); @@ -398,7 +400,7 @@ struct BlockGemmPipelineAGmemBGmemCReg index_t iCounter = num_loop - 1; - while (iCounter > 0) + while(iCounter > 0) { a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 8149fb4132..30bdcd679f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -38,7 +38,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -52,7 +52,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -66,7 +66,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -79,8 +79,8 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -88,26 +88,26 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); #endif @@ -132,7 +132,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -146,7 +146,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -160,7 +160,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -173,8 +173,8 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -182,26 +182,26 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); #endif @@ -228,11 +228,11 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } template @@ -254,11 +254,11 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } #if defined(ENABLE_INSTRUCTION_SCH) @@ -313,7 +313,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { - using ADataType = remove_cvref_t; + using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; return GetGlobalVectorLoadSize(); @@ -322,7 +322,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BDataType = remove_cvref_t; + using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; return GetGlobalVectorLoadSize(); diff --git a/example/ck_tile/99_toy_example/02_gemm/config.h b/example/ck_tile/99_toy_example/02_gemm/config.h index fc70599520..235337df7b 100644 --- a/example/ck_tile/99_toy_example/02_gemm/config.h +++ b/example/ck_tile/99_toy_example/02_gemm/config.h @@ -1,39 +1,38 @@ #if defined(KERNEL_A) - #define PADDING_K_FIRST - #define USING_MFMA_32x32x_8x2 +#define PADDING_K_FIRST +#define USING_MFMA_32x32x_8x2 #elif defined(KERNEL_B) - #define PADDING_K_FIRST - #define USING_MFMA_16x16x16 +#define PADDING_K_FIRST +#define USING_MFMA_16x16x16 #elif defined(KERNEL_C) - #define PADDING_K_FIRST - #define USING_MFMA_16x16x_16x2 +#define PADDING_K_FIRST +#define USING_MFMA_16x16x_16x2 #elif defined(KERNEL_D) - #define USING_MFMA_16x16x_16x2 - #define USING_XOR_BASED_BANK_CONFLICT_FREE +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE #elif defined(KERNEL_E) - #define USING_MFMA_16x16x_16x2 - #define USING_XOR_BASED_BANK_CONFLICT_FREE - #define ADJUST_BLOCK_TILE_SHAPE +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE #elif defined(KERNEL_F) - #define USING_MFMA_16x16x_16x2 - #define USING_XOR_BASED_BANK_CONFLICT_FREE - #define ADJUST_BLOCK_TILE_SHAPE - #define ENABLE_PREFETCH +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH #elif defined(KERNEL_G) - #define USING_MFMA_16x16x_16x2 - #define USING_XOR_BASED_BANK_CONFLICT_FREE - #define ADJUST_BLOCK_TILE_SHAPE - #define ENABLE_PREFETCH - #define ENABLE_INSTRUCTION_SCH +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH #elif defined(KERNEL_H) - #define USING_MFMA_16x16x_16x2 - #define USING_XOR_BASED_BANK_CONFLICT_FREE - #define ADJUST_BLOCK_TILE_SHAPE - #define ENABLE_PREFETCH - #define ENABLE_INSTRUCTION_SCH - #define ENABLE_CACHE_AWARE_WG_SCH +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH +#define ENABLE_CACHE_AWARE_WG_SCH #else - #define NAIVE_IMPLEMENTATION +#define NAIVE_IMPLEMENTATION #endif - diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp index 100dd82f61..aee29d1aa7 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -158,32 +158,30 @@ int main(int argc, char* argv[]) kGemmNPerBlock, kGemmKPerBlock>; - float ave_time = - ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, - ck_tile::make_kernel( - gemm_kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(a_buf.GetDeviceBuffer()), - static_cast(b_buf.GetDeviceBuffer()), - static_cast(c_buf.GetDeviceBuffer()), - M, - N, - K, - Lda, - Ldb, - Ldc, - CElementFunction{})); - auto pass = true; + float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel( + gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; if(verification) { // reference gemm ck_tile::HostTensor c_host_ref(c_lengths, c_strides); - reference_basic_gemm(a_host, - b_host, - c_host_ref); + reference_basic_gemm( + a_host, b_host, c_host_ref); c_buf.FromDevice(c_host_dev.mData.data()); pass &= ck_tile::check_err(c_host_dev, c_host_ref); std::cout << "valid:" << (pass ? "y" : "n") << std::endl; @@ -202,4 +200,3 @@ int main(int argc, char* argv[]) return !pass; } - diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 8cc3013d5f..8b581204b7 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -67,11 +67,8 @@ template struct Gemm { - using GridGemmProblem = GridGemmProblem; + using GridGemmProblem = + GridGemmProblem; struct GridGemmPolicy { @@ -81,57 +78,63 @@ struct Gemm static constexpr index_t kKPerBlock = kKPerBlock_; template - CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, - index_t N0) + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) { #if defined(ENABLE_CACHE_AWARE_WG_SCH) return [=](index_t block_1d_id) { - constexpr index_t M01 = 4; + constexpr index_t M01 = 4; constexpr index_t GroupNum = 8; const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; - const auto update_M0 = ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; const auto xcd_id = block_1d_id % GroupNum; const auto l_block_id = block_1d_id - (xcd_id % 2); const auto ridn = GroupNum * M01 * (update_N0 / 2); - const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; - const auto lu = (l_block_id % GroupNum) + rid * ridn; + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; - const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); - const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01)+ lu ) ) / GroupNum; + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = + (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); const auto total_update_size = update_N0 * update_M0; - if (block_1d_id >= total_update_size) { - auto x = (block_1d_id + 1) - total_update_size; + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; auto rlen = N0 - update_N0; auto rm = 0; auto rn = 0; - if (rlen > 0) { + if(rlen > 0) + { rm = (x - 1) / rlen; rn = x % rlen; } - if (rlen > 0 and rm < M0) { + if(rlen > 0 and rm < M0) + { n = rn + update_N0; m = rm; - } else { - x = x - rlen * M0; + } + else + { + x = x - rlen * M0; rm = (x - 1) / update_N0; rn = x % update_N0; - n = rn; - m = update_M0 + rm; + n = rn; + m = update_M0 + rm; } } return make_multi_index(m, n); - }; + }; #else const auto unmerge = make_merge_transform(make_tuple(N0, M0)); @@ -140,7 +143,6 @@ struct Gemm unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); - }; #endif } @@ -161,15 +163,15 @@ struct Gemm using GridGemm = GridGemm; CK_TILE_DEVICE void operator()(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - const index_t M, - const index_t N, - const index_t K, - const index_t Lda, - const index_t Ldb, - const index_t Ldc, - const CElementFunction& c_element_func) const + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const { const auto a_dram = [&] { return make_naive_tensor_view( diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp index 8fb8cdbff7..1ef3c8cc5d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp @@ -55,10 +55,8 @@ struct GridGemm __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; - const auto acc_block_tile = block_gemm_pipeline(a_block_window, - b_block_window, - K / kKPerBlock, - p_smem_char); + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); // cast to CDataType and apply CElementFunction const auto c_block_tile = tile_elementwise_in( diff --git a/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp index f2d0368bcb..07fc2f5221 100644 --- a/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp @@ -8,8 +8,8 @@ template void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, - const ck_tile::HostTensor& b_n_k, - ck_tile::HostTensor& c_m_n) + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) { const int N = b_n_k.mDesc.get_lengths()[0]; const int K = b_n_k.mDesc.get_lengths()[1]; @@ -24,12 +24,14 @@ void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, ADataType v_a = a_m_k(m, k); BDataType v_b = b_n_k(n, k); - v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); } c_m_n(m, n) = ck_tile::type_convert(v_acc); } }; - ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])( + std::thread::hardware_concurrency()); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 index 44dfac099c..5ab8e1e314 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -1,19 +1,22 @@ set(EXAMPLE_REDUCE "basic_flash_attention_fwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message("adding example ${EXAMPLE_REDUCE}") +#not using add_example_executable() to add this target, since we don't want this to have +#to be included in "make all/install/check" + message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) -target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) + target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +#NOTE : we turn off undefined - func - \ + template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - + Wno - float - equal) -target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ + EXAMPLE_REDUCE_COMPILE_OPTIONS}) -# TODO: we have to turn off this global prop, otherwise the progress bar generated -# by cmake will print too many files, execvp: /bin/sh: Argument list too long -# however, this property may affect global -# TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +#TODO : we have to turn off this global prop, otherwise the progress bar generated +#by cmake will print too many files, execvp : / bin / sh : Argument list too long +#however, this property may affect global +#TODO : consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index d502655210..76548985ab 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -54,12 +54,10 @@ struct BlockGemmARegBSmemCRegV1 NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * NPerBlock * KPerBlock / - (kBlockSize * B_LDS_RW_Width); + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); - constexpr index_t C_MFMA_Inst_Num = - MPerBlock * NPerBlock * KPerBlock / (kBlockSize / WaveSize) / - (MPerXDL * NPerXDL * KPerXDL); + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // B split schedule constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 @@ -92,8 +90,8 @@ struct BlockGemmARegBSmemCRegV1 ignore = i; static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, @@ -126,17 +124,19 @@ struct BlockGemmARegBSmemCRegV1 const ABlockTensorTmp& a_block_tensor_tmp, const BBlockWindowTmp& b_block_window_tmp) const { - static_assert(std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -154,13 +154,13 @@ struct BlockGemmARegBSmemCRegV1 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -190,11 +190,14 @@ struct BlockGemmARegBSmemCRegV1 auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - statically_indexed_array, - NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -206,9 +209,11 @@ struct BlockGemmARegBSmemCRegV1 }); // check C-block-distribution - static_assert(std::is_same_v, - remove_cvref_t>, "wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); using AWarpDstr = typename WG::AWarpDstr; using CWarpDstr = typename WG::CWarpDstr; @@ -216,8 +221,10 @@ struct BlockGemmARegBSmemCRegV1 using AWarpTensor = typename WG::AWarpTensor; using CWarpTensor = typename WG::CWarpTensor; - constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; @@ -260,16 +267,18 @@ struct BlockGemmARegBSmemCRegV1 __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, const BBlockWindowTmp& b_block_window_tmp) const { - static_assert(std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -287,13 +296,13 @@ struct BlockGemmARegBSmemCRegV1 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -324,11 +333,14 @@ struct BlockGemmARegBSmemCRegV1 auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - statically_indexed_array, - NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -348,8 +360,10 @@ struct BlockGemmARegBSmemCRegV1 using AWarpTensor = typename WG::AWarpTensor; using CWarpTensor = typename WG::CWarpTensor; - constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index d54b460b71..02fe7f54ad 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -212,8 +212,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< const ARegBlockTensorTmp& a_reg_block_tensor_tmp, void* p_smem) const { - static_assert(std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v>, + "wrong!"); static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], @@ -343,9 +344,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } @@ -357,7 +358,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 2cafb715a2..3be95fdbae 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -76,13 +76,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp index 1a620ba54b..000811e0ae 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_problem.hpp @@ -23,4 +23,3 @@ struct BlockGemmPipelineProblem }; } // namespace ck_tile - diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index 3ab79e23b1..4d48478084 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -36,7 +36,7 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() { constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; + constexpr index_t kKPack = 8; using BDataType = remove_cvref_t; @@ -46,8 +46,8 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -55,26 +55,25 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index 8ce1d6c6c7..fba49f9de5 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,14 +29,14 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; + [[maybe_unused]] ck_tile::index_t time_kernel = 0; if(argc == 4) { @@ -83,21 +83,21 @@ int main(int argc, char* argv[]) switch(init_method) { - case 0: break; - case 1: - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); - break; - case 2: - ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); - ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); - ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); - break; - default: - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + case 0: break; + case 1: + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 2: + ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); + break; + default: + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -126,42 +126,42 @@ int main(int argc, char* argv[]) constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true}, - ck_tile::make_kernel( - ck_tile::FlashAttentionFwd{}, - kGridSize, - kBlockSize, - 0, - static_cast(q_buf.GetDeviceBuffer()), - static_cast(k_buf.GetDeviceBuffer()), - static_cast(v_buf.GetDeviceBuffer()), - static_cast(o_buf.GetDeviceBuffer()), - M0, - N0, - K0, - N1, - Batch, - K0, // StrideQ - K0, // StrideK - N0, // StrideV - N1, // StrideO - M0 * K0, // BatchStrideQ - N0 * K0, // BatchStrideK - N1 * N0, // BatchStrideV - M0 * N1)); // BatchStrideO + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{}, + kGridSize, + kBlockSize, + 0, + static_cast(q_buf.GetDeviceBuffer()), + static_cast(k_buf.GetDeviceBuffer()), + static_cast(v_buf.GetDeviceBuffer()), + static_cast(o_buf.GetDeviceBuffer()), + M0, + N0, + K0, + N1, + Batch, + K0, // StrideQ + K0, // StrideK + N0, // StrideV + N1, // StrideO + M0 * K0, // BatchStrideQ + N0 * K0, // BatchStrideK + N1 * N0, // BatchStrideV + M0 * N1)); // BatchStrideO // reference auto pass = true; @@ -175,8 +175,8 @@ int main(int argc, char* argv[]) ck_tile::reference_batched_gemm( q_host, k_host, s_host_ref); - ck_tile::reference_batched_softmax(s_host_ref, - p_host_ref); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref); ck_tile::reference_batched_gemm( p_host_ref, v_host, o_host_ref); @@ -199,4 +199,3 @@ int main(int argc, char* argv[]) return !pass; } - diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index 085366c0b3..c880639621 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -15,52 +15,57 @@ namespace ck_tile { -CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, - index_t N0) +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) { return [=](index_t block_1d_id) { - constexpr index_t M01 = 4; + constexpr index_t M01 = 4; constexpr index_t GroupNum = 8; const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; - const auto update_M0 = ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; const auto xcd_id = block_1d_id % GroupNum; const auto l_block_id = block_1d_id - (xcd_id % 2); const auto ridn = GroupNum * M01 * (update_N0 / 2); - const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; - const auto lu = (l_block_id % GroupNum) + rid * ridn; + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; - const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); - const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01)+ lu ) ) / GroupNum; + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); const auto total_update_size = update_N0 * update_M0; - if (block_1d_id >= total_update_size) { - auto x = (block_1d_id + 1) - total_update_size; + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; auto rlen = N0 - update_N0; auto rm = 0; auto rn = 0; - if (rlen > 0) { + if(rlen > 0) + { rm = (x - 1) / rlen; rn = x % rlen; } - if (rlen > 0 and rm < M0) { + if(rlen > 0 and rm < M0) + { n = rn + update_N0; m = rm; - } else { - x = x - rlen * M0; + } + else + { + x = x - rlen * M0; rm = (x - 1) / update_N0; rn = x % update_N0; - n = rn; - m = update_M0 + rm; + n = rn; + m = update_M0 + rm; } } return make_multi_index(m, n); @@ -117,8 +122,10 @@ struct FlashAttentionFwd const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % num_tile_m0 * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % + num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % + num_tile_n1 * kN1PerBlock); #else const auto f = [](index_t dividend, index_t divisor) { @@ -129,9 +136,9 @@ struct FlashAttentionFwd }; const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); #endif diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 0d734aca3d..c10689567f 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -16,7 +16,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/reduce.hpp" - namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -40,27 +39,25 @@ template >; + using BlockGemm0Problem = + BlockGemmPipelineProblem>; using BlockGemm0Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; - using BlockGemm0Pipeline = - BlockGemmPipelineAGmemBGmemCReg; + using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg; // block gemm1 using BlockGemm1 = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem< - PDataType, - VDataType, - OaccDataType, - kBlockSize, - TileGemmShape>, + BlockGemmARegBSmemCRegProblem>, BlockGemmARegBSmemCRegV1DefaultPolicy>; // 3d, with padding @@ -68,7 +65,7 @@ struct FlashAttentionFwdImpl { constexpr index_t kNPerBlock = kN1PerBlock; constexpr index_t kKPerBlock = kK1PerBlock; - constexpr index_t kKPack = 4; + constexpr index_t kKPack = 4; constexpr auto dataTypeSize = sizeof(VDataType); constexpr auto NLdsLayer = @@ -76,8 +73,8 @@ struct FlashAttentionFwdImpl constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -85,26 +82,26 @@ struct FlashAttentionFwdImpl constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; @@ -125,11 +122,11 @@ struct FlashAttentionFwdImpl return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } __device__ static constexpr index_t GetStaticLdsSize() @@ -190,8 +187,8 @@ struct FlashAttentionFwdImpl // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS - auto v_lds = make_tensor_view(reinterpret_cast(smem_ptr), - MakeVLdsBlockDescriptor()); + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -229,8 +226,8 @@ struct FlashAttentionFwdImpl auto l = MLBlockTileType{}; tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = std::numeric_limits::lowest(); }, - m); + tile_elementwise_inout( + [](auto& e) { e = std::numeric_limits::lowest(); }, m); tile_elementwise_inout([](auto& e) { e = 0; }, l); // loop over Column of S (J loop) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_gemm.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_gemm.hpp index 2762e66464..afe8be4924 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_gemm.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_gemm.hpp @@ -24,14 +24,15 @@ void reference_batched_gemm(const ck_tile::HostTensor& a_b_m_k, ADataType v_a = a_b_m_k(batch, m, k); BDataType v_b = b_b_n_k(batch, n, k); - v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); } c_b_m_n(batch, m, n) = ck_tile::type_convert(v_acc); } }; - ck_tile::make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + ck_tile::make_ParallelTensorFunctor( + f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } - diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_softmax.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_softmax.hpp index 3713a22c6a..6f19f04b9d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_softmax.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/reference_batched_softmax.hpp @@ -7,7 +7,8 @@ #include "ck_tile/host/host_tensor.hpp" template -void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, ck_tile::HostTensor& b_b_m_n) +void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, + ck_tile::HostTensor& b_b_m_n) { const int N = a_b_m_n.mDesc.get_lengths()[2]; @@ -41,7 +42,7 @@ void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, ck } }; - ck_tile::make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( + ck_tile::make_ParallelTensorFunctor( + f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } - diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 index cc9873eeb3..87ec6628cd --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt @@ -1,59 +1,48 @@ -set(FLASH_ATTENTION_FWD_KNOWN_APIS "fwd") -set(FLASH_ATTENTION_FWD_ENABLE_APIS "fwd" CACHE STRING - "semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & link, or \"all\".") -if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all") - set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS}) -endif() +set(FLASH_ATTENTION_FWD_KNOWN_APIS + "fwd") set(FLASH_ATTENTION_FWD_ENABLE_APIS + "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & " + "link, or \"all\".") if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL + "all") set(FLASH_ATTENTION_FWD_ENABLE_APIS ${ + FLASH_ATTENTION_FWD_KNOWN_APIS}) endif() -execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} - --working_path ${CMAKE_CURRENT_BINARY_DIR} - --list_blobs - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}") -endif() + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR} / + generate.py-- api $ { FLASH_ATTENTION_FWD_ENABLE_APIS } --working_path $ { + CMAKE_CURRENT_BINARY_DIR + } --list_blobs RESULT_VARIABLE ret) if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}") endif() -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS) + file(STRINGS ${CMAKE_CURRENT_BINARY_DIR} / + flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS) -add_custom_command( - OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} - --working_path ${CMAKE_CURRENT_BINARY_DIR} - --gen_blobs -) + add_custom_command( + OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${ + CMAKE_CURRENT_LIST_DIR} / + generate.py-- api $ { FLASH_ATTENTION_FWD_ENABLE_APIS } --working_path $ { + CMAKE_CURRENT_BINARY_DIR + } --gen_blobs) -set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd") -message("adding example ${EXAMPLE_REDUCE}") + set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd") message( + "adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} - EXCLUDE_FROM_ALL - flash_attention_fwd.cpp -) + add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) -target_include_directories(${EXAMPLE_REDUCE} - PRIVATE - ${CMAKE_CURRENT_LIST_DIR} -) + target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${ + CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXAMPLE_REDUCE} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS}) + target_sources(${EXAMPLE_REDUCE} PRIVATE ${ + FLASH_ATTENTION_FWD_GEN_BLOBS}) -message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}") + message("FLASH_ATTENTION_FWD_GEN_BLOBS = " + "${FLASH_ATTENTION_FWD_GEN_BLOBS}") + set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - + undefined - func - template - Wno - float - + equal-- offload - compress) -set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress -) + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ + EXAMPLE_REDUCE_COMPILE_OPTIONS}) -target_compile_options(${EXAMPLE_REDUCE} - PRIVATE - ${EXAMPLE_REDUCE_COMPILE_OPTIONS} -) - -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 740c540d6c..830b2422b5 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -32,17 +32,19 @@ struct BlockGemmARegBSmemCRegV1 const ABlockTensorTmp& a_block_tensor_tmp, const BBlockWindowTmp& b_block_window_tmp) const { - static_assert(std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -60,13 +62,13 @@ struct BlockGemmARegBSmemCRegV1 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -96,11 +98,14 @@ struct BlockGemmARegBSmemCRegV1 auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - statically_indexed_array, - NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -112,9 +117,11 @@ struct BlockGemmARegBSmemCRegV1 }); // check C-block-distribution - static_assert(std::is_same_v, - remove_cvref_t>, "wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); using AWarpDstr = typename WG::AWarpDstr; using CWarpDstr = typename WG::CWarpDstr; @@ -122,8 +129,10 @@ struct BlockGemmARegBSmemCRegV1 using AWarpTensor = typename WG::AWarpTensor; using CWarpTensor = typename WG::CWarpTensor; - constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; @@ -166,16 +175,18 @@ struct BlockGemmARegBSmemCRegV1 __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, const BBlockWindowTmp& b_block_window_tmp) const { - static_assert(std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, "wrong!"); + KPerBlock == BlockGemmShape::kK, + "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -193,13 +204,13 @@ struct BlockGemmARegBSmemCRegV1 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -230,11 +241,14 @@ struct BlockGemmARegBSmemCRegV1 auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - statically_indexed_array, - NIterPerWarp> b_warp_windows; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -254,8 +268,10 @@ struct BlockGemmARegBSmemCRegV1 using AWarpTensor = typename WG::AWarpTensor; using CWarpTensor = typename WG::CWarpTensor; - constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4a8c2beeb7..ec9484ffd1 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -207,8 +207,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< const ARegBlockTensorTmp& a_reg_block_tensor_tmp, void* p_smem) const { - static_assert(std::is_same_v>, - "wrong!"); + static_assert( + std::is_same_v>, + "wrong!"); static_assert(kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], @@ -334,9 +335,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } @@ -348,7 +349,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 2cafb715a2..3be95fdbae 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -76,13 +76,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp index 1a620ba54b..000811e0ae 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_problem.hpp @@ -23,4 +23,3 @@ struct BlockGemmPipelineProblem }; } // namespace ck_tile - diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index 3ab79e23b1..4d48478084 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -36,7 +36,7 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() { constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; + constexpr index_t kKPack = 8; using BDataType = remove_cvref_t; @@ -46,8 +46,8 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -55,26 +55,25 @@ __host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index e3f4fde0bb..9b7c9b1c6c 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,14 +29,14 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; + [[maybe_unused]] ck_tile::index_t time_kernel = 0; if(argc == 4) { @@ -83,21 +83,21 @@ int main(int argc, char* argv[]) switch(init_method) { - case 0: break; - case 1: - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); - break; - case 2: - ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); - ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); - ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); - break; - default: - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + case 0: break; + case 1: + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 2: + ck_tile::FillUniformDistribution{-3.f, 3.f}(q_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(k_host); + ck_tile::FillUniformDistribution{-3.f, 3.f}(v_host); + break; + default: + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -108,9 +108,8 @@ int main(int argc, char* argv[]) k_buf.ToDevice(k_host.mData.data()); v_buf.ToDevice(v_host.mData.data()); - // Construct the FlashAttnArgs object with your arguments - ck_tile::FlashAttnArgs flash_attention_args { + ck_tile::FlashAttnArgs flash_attention_args{ static_cast(q_buf.GetDeviceBuffer()), static_cast(k_buf.GetDeviceBuffer()), static_cast(v_buf.GetDeviceBuffer()), @@ -120,25 +119,25 @@ int main(int argc, char* argv[]) K0, N1, Batch, - K0, // strideQ - K0, // strideK - N0, // strideV - N1, // strideO + K0, // strideQ + K0, // strideK + N0, // strideV + N1, // strideO M0 * K0, // batchStrideQ N0 * K0, // batchStrideK N1 * N0, // batchStrideV - M0 * N1 // batchStrideO + M0 * N1 // batchStrideO }; - float ave_time = ck_tile::flash_attention_fwd - (flash_attention_args, ck_tile::stream_config{nullptr, true}); + float ave_time = ck_tile::flash_attention_fwd(flash_attention_args, + ck_tile::stream_config{nullptr, true}); // reference auto pass = true; @@ -152,8 +151,8 @@ int main(int argc, char* argv[]) ck_tile::reference_batched_gemm( q_host, k_host, s_host_ref); - ck_tile::reference_batched_softmax(s_host_ref, - p_host_ref); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref); ck_tile::reference_batched_gemm( p_host_ref, v_host, o_host_ref); @@ -176,4 +175,3 @@ int main(int argc, char* argv[]) return !pass; } - diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index c95c291123..56767bec66 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -16,11 +16,7 @@ namespace ck_tile { - -template +template struct FlashAttnArgs { // Pointers to device buffers for Q, K, V, O @@ -29,29 +25,26 @@ struct FlashAttnArgs VDataType* v_ptr; ODataType* o_ptr; - // Problem sizes + // Problem sizes index_t M0; index_t N0; index_t K0; index_t N1; index_t Batch; - // Strides within a batch - index_t strideQ; - index_t strideK; - index_t strideV; - index_t strideO; + // Strides within a batch + index_t strideQ; + index_t strideK; + index_t strideV; + index_t strideO; - // Batch strides - index_t batchStrideQ; - index_t batchStrideK; - index_t batchStrideV; - index_t batchStrideO; + // Batch strides + index_t batchStrideQ; + index_t batchStrideK; + index_t batchStrideV; + index_t batchStrideO; }; - - - // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) // O[M0, N1] = P[M0, N0] * V[N1, N0] @@ -102,8 +95,10 @@ struct FlashAttentionFwd const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % num_tile_m0 * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % + num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % + num_tile_n1 * kN1PerBlock); #else const auto f = [](index_t dividend, index_t divisor) { @@ -114,9 +109,9 @@ struct FlashAttentionFwd }; const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); #endif @@ -179,7 +174,7 @@ struct FlashAttentionFwd // static constexpr index_t kK0PerBlock = kK0PerBlock_; // static constexpr index_t kN1PerBlock = kN1PerBlock_; // static constexpr index_t kK1PerBlock = kK1PerBlock_; - + // static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD // static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; // static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; @@ -215,7 +210,7 @@ struct FlashAttentionFwd // typename VDataType, // typename ODataType, // typename Traits_> -// float flash_attention_fwd_(const FlashAttnArgs& a, +// float flash_attention_fwd_(const FlashAttnArgs& a, // const ck_tile::stream_config& stream_config); // // TODO: fwd_common.cpp @@ -224,13 +219,13 @@ struct FlashAttentionFwd // typename VDataType, // typename ODataType, // typename Traits_> -// float flash_attention_fwd_(const FlashAttnArgs& a, +// float flash_attention_fwd_(const FlashAttnArgs& a, // const ck_tile::stream_config& stream_config) { -// using SaccDataType = typename Traits_::SaccDataType; -// using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -// using PDataType = typename Traits_::PDataType; -// using OaccDataType = typename Traits_::OaccDataType; - +// using SaccDataType = typename Traits_::SaccDataType; +// using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; +// using PDataType = typename Traits_::PDataType; +// using OaccDataType = typename Traits_::OaccDataType; + // index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); // std::cout << "grid size " << kGridSize << std::endl; @@ -284,7 +279,7 @@ struct FlashAttentionFwd // typename PDataType, // typename OaccDataType, // typename ODataType> -// float flash_attention_fwd(const FlashAttnArgs& a, +// float flash_attention_fwd(const FlashAttnArgs& a, // const ck_tile::stream_config& stream_config) { // constexpr ck_tile::index_t kM0PerBlock = 128; // constexpr ck_tile::index_t kN0PerBlock = 128; @@ -295,26 +290,25 @@ struct FlashAttentionFwd // constexpr ck_tile::index_t kBlockSize = 256; // constexpr ck_tile::index_t kHeadDim = 128; -// return flash_attention_fwd_> // (a, stream_config); // } - // TODO: change to only declare // TODO: fwd_api.cpp template -float flash_attention_fwd(const FlashAttnArgs& a, +float flash_attention_fwd(const FlashAttnArgs& a, const stream_config& stream_config); - } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 94016025c4..bb7df5469b 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -16,7 +16,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/reduce.hpp" - namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -40,27 +39,25 @@ template >; + using BlockGemm0Problem = + BlockGemmPipelineProblem>; using BlockGemm0Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; - using BlockGemm0Pipeline = - BlockGemmPipelineAGmemBGmemCReg; + using BlockGemm0Pipeline = BlockGemmPipelineAGmemBGmemCReg; // block gemm1 using BlockGemm1 = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem< - PDataType, - VDataType, - OaccDataType, - kBlockSize, - TileGemmShape>, + BlockGemmARegBSmemCRegProblem>, BlockGemmARegBSmemCRegV1DefaultPolicy>; // 3d, with padding @@ -68,7 +65,7 @@ struct FlashAttentionFwdImpl { constexpr index_t kNPerBlock = kN1PerBlock; constexpr index_t kKPerBlock = kK1PerBlock; - constexpr index_t kKPack = 4; + constexpr index_t kKPack = 4; constexpr auto dataTypeSize = sizeof(VDataType); constexpr auto NLdsLayer = @@ -76,8 +73,8 @@ struct FlashAttentionFwdImpl constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -85,26 +82,26 @@ struct FlashAttentionFwdImpl constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; @@ -125,11 +122,11 @@ struct FlashAttentionFwdImpl return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } __device__ static constexpr index_t GetStaticLdsSize() @@ -190,8 +187,8 @@ struct FlashAttentionFwdImpl // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS - auto v_lds = make_tensor_view(reinterpret_cast(smem_ptr), - MakeVLdsBlockDescriptor()); + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -229,8 +226,8 @@ struct FlashAttentionFwdImpl auto l = MLBlockTileType{}; tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = std::numeric_limits::lowest(); }, - m); + tile_elementwise_inout( + [](auto& e) { e = std::numeric_limits::lowest(); }, m); tile_elementwise_inout([](auto& e) { e = 0; }, l); // loop over Column of S (J loop) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 4a36179dbe..ac6c16c797 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -1,23 +1,15 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#SPDX - License - Identifier : MIT +#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved. -import argparse -from enum import IntEnum -from pathlib import Path -import sys -from typing import List, Optional, Any -import functools -import itertools -import copy -from dataclasses import dataclass +import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass -# def get_if_str(idx, total, last_else=True): -# if idx == 0: -# return 'if' -# elif idx < total - 1: -# return 'else if' -# else: -# return 'else' if last_else else 'else if' +#def get_if_str(idx, total, last_else = True): +#if idx == 0: +#return 'if' +#elif idx < total - 1: +#return 'else if' +#else: +#return 'else' if last_else else 'else if' def get_if_str(size_, total, last_else=True): if size_ == "small": @@ -90,77 +82,80 @@ using traits_ = flash_attention_fwd_traits_; """ -# API_COMMON_HEADER = """ -# // SPDX-License-Identifier: MIT -# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#API_COMMON_HEADER = "" \ + " +#// SPDX-License-Identifier: MIT +#// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -# #include +# #include # #include "flash_attention_fwd.hpp" -# #include +# #include # #pragma once -# using S = ck_tile::stream_config; -# using A = FlashAttnArgs; +#using S = ck_tile::stream_config; +#using A = FlashAttnArgs; -# {F_traits_define} +#{F_traits_define } -# template -# float flash_attention_fwd_(const FlashAttnArgs& a, -# const ck_tile::stream_config& stream_config) {{ -# using SaccDataType = typename Traits_::SaccDataType; -# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -# using PDataType = typename Traits_::PDataType; -# using OaccDataType = typename Traits_::OaccDataType; - -# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); +#template < typename QDataType, +#typename KDataType, +#typename VDataType, +#typename ODataType, +#typename Traits_> +#float flash_attention_fwd_(const FlashAttnArgs& a, +#const ck_tile::stream_config& stream_config){{ +#using SaccDataType = typename Traits_::SaccDataType; +#using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; +#using PDataType = typename Traits_::PDataType; +#using OaccDataType = typename Traits_::OaccDataType; -# if(stream_config.log_level_ > 0) -# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; +#index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); -# return ck_tile::launch_kernel(stream_config, -# ck_tile::make_kernel( -# ck_tile::FlashAttentionFwd{{}}, -# kGridSize, -# Traits_::kBlockSize, -# 0, -# a.q_ptr, -# a.k_ptr, -# a.v_ptr, -# a.o_ptr, -# a.M0, -# a.N0, -# a.K0, -# a.N1, -# a.Batch, -# a.strideQ, // StrideQ -# a.strideK, // StrideK -# a.strideV, // StrideV -# a.strideO, // StrideO -# a.batchStrideQ, // BatchStrideQ -# a.batchStrideK, // BatchStrideK -# a.batchStrideV, // BatchStrideV -# a.batchStrideO)); // BatchStrideO -# }} -# """ +#if (stream_config.log_level_ > 0) +#std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim \ + << ">" << std::flush; + +#return ck_tile::launch_kernel(stream_config, +#ck_tile::make_kernel < Traits_::kBlockSize, Traits_::kBlockPerCu>( +#ck_tile::FlashAttentionFwd < QDataType, +#KDataType, +#VDataType, +#SaccDataType, +#SMPLComputeDataType, +#PDataType, +#OaccDataType, +#ODataType, +#Traits_::kBlockSize, +#Traits_::kHeadDim, +#Traits_::kM0PerBlock, +#Traits_::kN0PerBlock, +#Traits_::kK0PerBlock, +#Traits_::kN1PerBlock, +#Traits_::kK1PerBlock>{{} }, +#kGridSize, +#Traits_::kBlockSize, +# 0, +#a.q_ptr, +#a.k_ptr, +#a.v_ptr, +#a.o_ptr, +#a.M0, +#a.N0, +#a.K0, +#a.N1, +#a.Batch, +#a.strideQ, // StrideQ +#a.strideK, // StrideK +#a.strideV, // StrideV +#a.strideO, // StrideO +#a.batchStrideQ, // BatchStrideQ +#a.batchStrideK, // BatchStrideK +#a.batchStrideV, // BatchStrideV +#a.batchStrideO)); // BatchStrideO +#} } +#"" \ + " API_BASE = """ // SPDX-License-Identifier: MIT @@ -204,14 +199,19 @@ template float flash_attention_fwd && std::is_same_v && std::is_same_v && std::is_same_v) {{ -# {F_per_size_case} -# }} -# """ -# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{ -# {F_inner_dispatch} -# }} -# """ +#API_PER_DTYPE = \ + "" \ + " {F_if}(std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) {{ +#{F_per_size_case } +#} } +#"" \ + " +#API_PER_SIZE_CASE = "" \ + " {F_if} {F_SIZE_COND} {{ +#{F_inner_dispatch } +#} } +#"" \ + " API_INNER_CASE = """ {F_if} {F_VEC_COND} r = flash_attention_fwd_>(a, stream_config); """ @@ -430,7 +430,7 @@ float flash_attention_fwd_(const FlashAttnArgs str: - # Sort based on dtype +#Sort based on dtype t_dtype_dict = {} blobs = self.get_blobs(args) @@ -464,28 +464,28 @@ float flash_attention_fwd_(const FlashAttnArgs& a_b_m_k, ADataType v_a = a_b_m_k(batch, m, k); BDataType v_b = b_b_n_k(batch, n, k); - v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); } c_b_m_n(batch, m, n) = ck_tile::type_convert(v_acc); } }; - ck_tile::make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( + ck_tile::make_ParallelTensorFunctor( + f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } - diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp index 3713a22c6a..6f19f04b9d 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/reference_batched_softmax.hpp @@ -7,7 +7,8 @@ #include "ck_tile/host/host_tensor.hpp" template -void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, ck_tile::HostTensor& b_b_m_n) +void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, + ck_tile::HostTensor& b_b_m_n) { const int N = a_b_m_n.mDesc.get_lengths()[2]; @@ -41,7 +42,7 @@ void reference_batched_softmax(const ck_tile::HostTensor& a_b_m_n, ck } }; - ck_tile::make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( + ck_tile::make_ParallelTensorFunctor( + f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])( std::thread::hardware_concurrency()); } - From 0a1227a5e770b021f08c66e3e115de37860d9e7f Mon Sep 17 00:00:00 2001 From: BoboFang Date: Tue, 22 Apr 2025 15:24:55 +0000 Subject: [PATCH 09/21] Fix error after clang-format --- .../99_toy_example/01_add/CMakeLists.txt | 33 ++- .../99_toy_example/02_gemm/CMakeLists.txt | 41 ++- .../03_flash_attention_fwd/CMakeLists.txt | 31 +- .../CMakeLists.txt | 85 +++--- .../generate.py | 272 +++++++++--------- 5 files changed, 234 insertions(+), 228 deletions(-) mode change 100644 => 100755 example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt mode change 100644 => 100755 example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt diff --git a/example/ck_tile/99_toy_example/01_add/CMakeLists.txt b/example/ck_tile/99_toy_example/01_add/CMakeLists.txt index ad51b75885..f241173fe7 100644 --- a/example/ck_tile/99_toy_example/01_add/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/01_add/CMakeLists.txt @@ -1,23 +1,22 @@ set(EXAMPLE_REDUCE "add") -#not using add_example_executable() to add this target, since we don't want this to have -#to be included in "make all/install/check" - message("adding example ${EXAMPLE_REDUCE}") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_REDUCE}") - add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add.cpp) target_include_directories(${ - EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add.cpp) +target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -#generate assembly -#list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - v-- save - temps - Wno - gnu - line - marker) +# generate assembly +# list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -#NOTE : we turn off undefined - func - \ - template to let source compile without explicit declare function specializations - list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - Wno - - float - equal) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) - target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) -#TODO : we have to turn off this global prop, otherwise the progress bar generated -#by cmake will print too many files, execvp : / bin / sh : Argument list too long -#however, this property may affect global -#TODO : consider codegen a makefile by us - set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt index afb32d49a9..02863398fe 100644 --- a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt @@ -1,28 +1,27 @@ set(EXAMPLE_REDUCE "basic_gemm") -#not using add_example_executable() to add this target, since we don't want this to have -#to be included in "make all/install/check" - message("adding example ${EXAMPLE_REDUCE}") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_REDUCE}") - add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm.cpp) target_include_directories(${ - EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm.cpp) +target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -#generate assembly -#list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - v-- save - temps - Wno - gnu - line - marker) +# generate assembly +# list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -#NOTE : we turn off undefined - func - \ - template to let source compile without explicit declare function specializations - list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - Wno - - float - equal) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) - if(DEFINED kernel) message("Compiling with Kernel: ${kernel}") - target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel} = 1) - endif() +if(DEFINED kernel) + message("Compiling with Kernel: ${kernel}") + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1) +endif() - target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ - EXAMPLE_REDUCE_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) -#TODO : we have to turn off this global prop, otherwise the progress bar generated -#by cmake will print too many files, execvp : / bin / sh : Argument list too long -#however, this property may affect global -#TODO : consider codegen a makefile by us - set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt old mode 100644 new mode 100755 index 5ab8e1e314..44dfac099c --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -1,22 +1,19 @@ set(EXAMPLE_REDUCE "basic_flash_attention_fwd") -#not using add_example_executable() to add this target, since we don't want this to have -#to be included in "make all/install/check" - message("adding example ${EXAMPLE_REDUCE}") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_REDUCE}") - add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) - target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) - set(EXAMPLE_REDUCE_COMPILE_OPTIONS) +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) +target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -#NOTE : we turn off undefined - func - \ - template to let source compile without explicit declare function specializations - list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - undefined - func - template - - Wno - float - equal) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) - target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ - EXAMPLE_REDUCE_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) -#TODO : we have to turn off this global prop, otherwise the progress bar generated -#by cmake will print too many files, execvp : / bin / sh : Argument list too long -#however, this property may affect global -#TODO : consider codegen a makefile by us - set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt old mode 100644 new mode 100755 index 87ec6628cd..e9a697fde3 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt @@ -1,48 +1,59 @@ -set(FLASH_ATTENTION_FWD_KNOWN_APIS - "fwd") set(FLASH_ATTENTION_FWD_ENABLE_APIS - "fwd" CACHE STRING - "semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & " - "link, or \"all\".") if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL - "all") set(FLASH_ATTENTION_FWD_ENABLE_APIS ${ - FLASH_ATTENTION_FWD_KNOWN_APIS}) endif() +set(FLASH_ATTENTION_FWD_KNOWN_APIS "fwd") +set(FLASH_ATTENTION_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FLASH_ATTENTION_FWD_KNOWN_APIS}) & link, or \"all\".") +if(FLASH_ATTENTION_FWD_ENABLE_APIS STREQUAL "all") + set(FLASH_ATTENTION_FWD_ENABLE_APIS ${FLASH_ATTENTION_FWD_KNOWN_APIS}) +endif() - execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR} / - generate.py-- api $ { FLASH_ATTENTION_FWD_ENABLE_APIS } --working_path $ { - CMAKE_CURRENT_BINARY_DIR - } --list_blobs RESULT_VARIABLE ret) if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}") endif() +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list Flash Attention kernels via Python. ${ret}") +endif() - file(STRINGS ${CMAKE_CURRENT_BINARY_DIR} / - flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATTENTION_FWD_GEN_BLOBS) - add_custom_command( - OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${ - CMAKE_CURRENT_LIST_DIR} / - generate.py-- api $ { FLASH_ATTENTION_FWD_ENABLE_APIS } --working_path $ { - CMAKE_CURRENT_BINARY_DIR - } --gen_blobs) +add_custom_command( + OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --gen_blobs +) - set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd") message( - "adding example ${EXAMPLE_REDUCE}") +set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd") +message("adding example ${EXAMPLE_REDUCE}") - add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL flash_attention_fwd.cpp) +add_executable(${EXAMPLE_REDUCE} + EXCLUDE_FROM_ALL + flash_attention_fwd.cpp +) - target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${ - CMAKE_CURRENT_LIST_DIR}) +target_include_directories(${EXAMPLE_REDUCE} + PRIVATE + ${CMAKE_CURRENT_LIST_DIR} +) - target_sources(${EXAMPLE_REDUCE} PRIVATE ${ - FLASH_ATTENTION_FWD_GEN_BLOBS}) +target_sources(${EXAMPLE_REDUCE} PRIVATE ${FLASH_ATTENTION_FWD_GEN_BLOBS}) - message("FLASH_ATTENTION_FWD_GEN_BLOBS = " - "${FLASH_ATTENTION_FWD_GEN_BLOBS}") +message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}") - set(EXAMPLE_REDUCE_COMPILE_OPTIONS) - list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - Wno - - undefined - func - template - Wno - float - - equal-- offload - compress) - target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${ - EXAMPLE_REDUCE_COMPILE_OPTIONS}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) - set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file +target_compile_options(${EXAMPLE_REDUCE} + PRIVATE + ${EXAMPLE_REDUCE_COMPILE_OPTIONS} +) + +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index ac6c16c797..040133f8e9 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -1,15 +1,23 @@ -#SPDX - License - Identifier : MIT -#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved. +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass -#def get_if_str(idx, total, last_else = True): -#if idx == 0: -#return 'if' -#elif idx < total - 1: -#return 'else if' -#else: -#return 'else' if last_else else 'else if' +# def get_if_str(idx, total, last_else=True): +# if idx == 0: +# return 'if' +# elif idx < total - 1: +# return 'else if' +# else: +# return 'else' if last_else else 'else if' def get_if_str(size_, total, last_else=True): if size_ == "small": @@ -26,7 +34,7 @@ def BOOL_MAP(b_) -> str: class FlashAttentionFwdCodegen: API_TRAITS_DEFINE = """ - + template ; """ -#API_COMMON_HEADER = "" \ - " -#// SPDX-License-Identifier: MIT -#// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# API_COMMON_HEADER = """ +# // SPDX-License-Identifier: MIT +# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -# #include +# #include # #include "flash_attention_fwd.hpp" -# #include +# #include # #pragma once -#using S = ck_tile::stream_config; -#using A = FlashAttnArgs; +# using S = ck_tile::stream_config; +# using A = FlashAttnArgs; -#{F_traits_define } +# {F_traits_define} -#template < typename QDataType, -#typename KDataType, -#typename VDataType, -#typename ODataType, -#typename Traits_> -#float flash_attention_fwd_(const FlashAttnArgs& a, -#const ck_tile::stream_config& stream_config){{ -#using SaccDataType = typename Traits_::SaccDataType; -#using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -#using PDataType = typename Traits_::PDataType; -#using OaccDataType = typename Traits_::OaccDataType; +# template +# float flash_attention_fwd_(const FlashAttnArgs& a, +# const ck_tile::stream_config& stream_config) {{ +# using SaccDataType = typename Traits_::SaccDataType; +# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; +# using PDataType = typename Traits_::PDataType; +# using OaccDataType = typename Traits_::OaccDataType; -#index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); +# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); -#if (stream_config.log_level_ > 0) -#std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim \ - << ">" << std::flush; +# if(stream_config.log_level_ > 0) +# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; -#return ck_tile::launch_kernel(stream_config, -#ck_tile::make_kernel < Traits_::kBlockSize, Traits_::kBlockPerCu>( -#ck_tile::FlashAttentionFwd < QDataType, -#KDataType, -#VDataType, -#SaccDataType, -#SMPLComputeDataType, -#PDataType, -#OaccDataType, -#ODataType, -#Traits_::kBlockSize, -#Traits_::kHeadDim, -#Traits_::kM0PerBlock, -#Traits_::kN0PerBlock, -#Traits_::kK0PerBlock, -#Traits_::kN1PerBlock, -#Traits_::kK1PerBlock>{{} }, -#kGridSize, -#Traits_::kBlockSize, -# 0, -#a.q_ptr, -#a.k_ptr, -#a.v_ptr, -#a.o_ptr, -#a.M0, -#a.N0, -#a.K0, -#a.N1, -#a.Batch, -#a.strideQ, // StrideQ -#a.strideK, // StrideK -#a.strideV, // StrideV -#a.strideO, // StrideO -#a.batchStrideQ, // BatchStrideQ -#a.batchStrideK, // BatchStrideK -#a.batchStrideV, // BatchStrideV -#a.batchStrideO)); // BatchStrideO -#} } -#"" \ - " +# return ck_tile::launch_kernel(stream_config, +# ck_tile::make_kernel( +# ck_tile::FlashAttentionFwd{{}}, +# kGridSize, +# Traits_::kBlockSize, +# 0, +# a.q_ptr, +# a.k_ptr, +# a.v_ptr, +# a.o_ptr, +# a.M0, +# a.N0, +# a.K0, +# a.N1, +# a.Batch, +# a.strideQ, // StrideQ +# a.strideK, // StrideK +# a.strideV, // StrideV +# a.strideO, // StrideO +# a.batchStrideQ, // BatchStrideQ +# a.batchStrideK, // BatchStrideK +# a.batchStrideV, // BatchStrideV +# a.batchStrideO)); // BatchStrideO +# }} +# """ API_BASE = """ // SPDX-License-Identifier: MIT @@ -199,19 +204,14 @@ template float flash_attention_fwd && std::is_same_v && std::is_same_v && std::is_same_v) {{ -#{F_per_size_case } -#} } -#"" \ - " -#API_PER_SIZE_CASE = "" \ - " {F_if} {F_SIZE_COND} {{ -#{F_inner_dispatch } -#} } -#"" \ - " +# API_PER_DTYPE = """ {F_if}(std::is_same_v && std::is_same_v && std::is_same_v && std::is_same_v) {{ +# {F_per_size_case} +# }} +# """ +# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{ +# {F_inner_dispatch} +# }} +# """ API_INNER_CASE = """ {F_if} {F_VEC_COND} r = flash_attention_fwd_>(a, stream_config); """ @@ -224,7 +224,7 @@ template float flash_attention_fwd float flash_attention_fwd_(const FlashAttnArgs& a, const ck_tile::stream_config& stream_config) {{ - using SaccDataType = typename Traits_::SaccDataType; - using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; - using PDataType = typename Traits_::PDataType; - using OaccDataType = typename Traits_::OaccDataType; - + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; + index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); if(stream_config.log_level_ > 0) @@ -430,10 +430,10 @@ float flash_attention_fwd_(const FlashAttnArgs str: -#Sort based on dtype + # Sort based on dtype t_dtype_dict = {} blobs = self.get_blobs(args) - + for blob in blobs: if blob.F_DataTypePair not in t_dtype_dict: t_dtype_dict[blob.F_DataTypePair] = {} @@ -445,16 +445,16 @@ float flash_attention_fwd_(const FlashAttnArgs None: w_p = Path(self.working_path) list_p = w_p / 'flash_attention_fwd_blobs.txt' blobs = self.get_blobs(args) - + with list_p.open('w') as list_f: -#API related files + # API related files list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") -#Kernel instance files + # Kernel instance files for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") @@ -557,7 +557,7 @@ float flash_attention_fwd_(const FlashAttnArgs Date: Tue, 22 Apr 2025 15:34:29 +0000 Subject: [PATCH 10/21] Change the permission of FA CMakeList.txt to 644 --- .../ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt | 0 .../99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt mode change 100755 => 100644 example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 From 068d9fdbf71a04b4e25112e77942a5386ca503df Mon Sep 17 00:00:00 2001 From: BoboFang Date: Wed, 23 Apr 2025 09:47:37 +0000 Subject: [PATCH 11/21] Add MakeBlock2TileMap in 04_codegen_flash_attention_fwd --- .../flash_attention_fwd.hpp | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index 56767bec66..324d7924eb 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -16,6 +16,63 @@ namespace ck_tile { +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + template struct FlashAttnArgs { From 35de33c57bfb7ff388632ebb88fcac00a4cef258 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 23 Apr 2025 11:48:06 +0800 Subject: [PATCH 12/21] Add codegen instances The following examples have been tested for 04_codegen: ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 256 256 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 64 64 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 32 32 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 128 128 ./bin/codegen_basic_flash_attention_fwd 1 1 64 2048 2048 128 128 ./bin/codegen_basic_flash_attention_fwd 1 1 64 512 512 128 128 --- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 1 - .../flash_attention_fwd.cpp | 22 +- .../block_gemm_areg_bsmem_creg_v1.hpp | 8 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 19 +- ...emm_areg_bsmem_creg_v1_iteratek_policy.hpp | 19 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 4 +- .../flash_attention_fwd.cpp | 22 +- .../generate.py | 200 +++++------------- 8 files changed, 116 insertions(+), 179 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 3be95fdbae..c9e1e30c57 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -4,7 +4,6 @@ #pragma once #include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" namespace ck_tile { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index fba49f9de5..c797714421 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -36,25 +36,21 @@ int main(int argc, char* argv[]) ck_tile::index_t N1 = 128; // HeadDim ck_tile::index_t verification = 0; ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 830b2422b5..33d36954c0 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -38,6 +38,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -46,7 +48,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -180,6 +182,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -188,7 +192,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index fb1516eb52..6eee7f0d1e 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1DefaultPolicy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 32dc09f95e..3924a66daf 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1K8Policy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 3be95fdbae..4dd0c9a1e0 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -62,11 +62,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = remove_cvref_t; + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = AKDim; constexpr auto config = - BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 9b7c9b1c6c..48680a218a 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -36,25 +36,21 @@ int main(int argc, char* argv[]) ck_tile::index_t N1 = 128; // HeadDim ck_tile::index_t verification = 0; ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 040133f8e9..00bc91cadc 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -11,16 +11,8 @@ import itertools import copy from dataclasses import dataclass -# def get_if_str(idx, total, last_else=True): -# if idx == 0: -# return 'if' -# elif idx < total - 1: -# return 'else if' -# else: -# return 'else' if last_else else 'else if' - def get_if_str(size_, total, last_else=True): - if size_ == "small": + if size_ == "head_dim_256_seq_4096": return 'if' else: return 'else if' @@ -39,13 +31,13 @@ template + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ { using SaccDataType = ck_tile::remove_cvref_t; @@ -62,7 +54,7 @@ struct flash_attention_fwd_traits_ static constexpr index_t kK1PerBlock = kK1PerBlock_; static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; }; @@ -70,13 +62,13 @@ template + ck_tile::index_t kBlockSize = 256, + ck_tile::index_t kHeadDim = 128, + ck_tile::index_t kM0PerBlock = 128, + ck_tile::index_t kN0PerBlock = 128, + ck_tile::index_t kK0PerBlock = 64, + ck_tile::index_t kN1PerBlock = 128, + ck_tile::index_t kK1PerBlock = 64> using traits_ = flash_attention_fwd_traits_; """ -# API_COMMON_HEADER = """ -# // SPDX-License-Identifier: MIT -# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -# #include -# #include "flash_attention_fwd.hpp" -# #include - -# #pragma once - -# using S = ck_tile::stream_config; -# using A = FlashAttnArgs; - -# {F_traits_define} - -# template -# float flash_attention_fwd_(const FlashAttnArgs& a, -# const ck_tile::stream_config& stream_config) {{ -# using SaccDataType = typename Traits_::SaccDataType; -# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -# using PDataType = typename Traits_::PDataType; -# using OaccDataType = typename Traits_::OaccDataType; - -# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - -# if(stream_config.log_level_ > 0) -# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; - -# return ck_tile::launch_kernel(stream_config, -# ck_tile::make_kernel( -# ck_tile::FlashAttentionFwd{{}}, -# kGridSize, -# Traits_::kBlockSize, -# 0, -# a.q_ptr, -# a.k_ptr, -# a.v_ptr, -# a.o_ptr, -# a.M0, -# a.N0, -# a.K0, -# a.N1, -# a.Batch, -# a.strideQ, // StrideQ -# a.strideK, // StrideK -# a.strideV, // StrideV -# a.strideO, // StrideO -# a.batchStrideQ, // BatchStrideQ -# a.batchStrideK, // BatchStrideK -# a.batchStrideV, // BatchStrideV -# a.batchStrideO)); // BatchStrideO -# }} -# """ - API_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. @@ -204,14 +124,6 @@ template float flash_attention_fwd && std::is_same_v && std::is_same_v && std::is_same_v) {{ -# {F_per_size_case} -# }} -# """ -# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{ -# {F_inner_dispatch} -# }} -# """ API_INNER_CASE = """ {F_if} {F_VEC_COND} r = flash_attention_fwd_>(a, stream_config); """ @@ -320,13 +232,13 @@ template + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ {{ using SaccDataType = ck_tile::remove_cvref_t; @@ -456,36 +368,28 @@ float flash_attention_fwd_(const FlashAttnArgs Date: Wed, 23 Apr 2025 13:27:32 +0000 Subject: [PATCH 13/21] [GEMM] Add define macro for unused a/b blk window --- .../02_gemm/block_gemm_asmem_bsmem_creg.hpp | 54 ++++++++----------- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 4 +- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index 0fb5d2de7d..5c58fa3d60 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -111,8 +111,8 @@ struct BlockGemmASmemBSmemCReg // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v && @@ -127,14 +127,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - // using WarpGemm = remove_cvref_t())>; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -142,7 +139,7 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -158,13 +155,12 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -180,16 +176,16 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; #if defined(ENABLE_PREFETCH) #pragma message("local data share prefetch") @@ -200,7 +196,7 @@ struct BlockGemmASmemBSmemCReg a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; #if defined(ENABLE_PREFETCH) b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( @@ -209,17 +205,17 @@ struct BlockGemmASmemBSmemCReg #else b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM + // Warp GEMM WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), @@ -231,8 +227,8 @@ struct BlockGemmASmemBSmemCReg // C = A * B template - CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v, @@ -246,14 +242,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - // constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - // using WarpGemm = remove_cvref_t{}))>; - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -261,7 +254,7 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -277,13 +270,12 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -299,11 +291,11 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif static_assert(std::is_same_v, "wrong!"); @@ -323,10 +315,10 @@ struct BlockGemmASmemBSmemCReg auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - // hot loop: + // Hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; #if defined(ENABLE_PREFETCH) a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( @@ -336,7 +328,7 @@ struct BlockGemmASmemBSmemCReg a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; #if defined(ENABLE_PREFETCH) b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( @@ -345,10 +337,10 @@ struct BlockGemmASmemBSmemCReg #else b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - // warp GEMM + // Warp GEMM if constexpr(KIterPerWarp == 0) { // c = a * b @@ -364,7 +356,7 @@ struct BlockGemmASmemBSmemCReg WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 57a0614c7f..effcc2b101 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -307,7 +307,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // Gemm pipeline start #if defined(ENABLE_PREFETCH) - +#pragma message("global prefetch") // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -342,7 +342,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // Main body if(num_loop > 2) { - index_t i = 0; + index_t iCounter = 0; do { block_sync_lds(); From ce4061847b73443d26be2f99dc3c34dfacd35460 Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 23 Apr 2025 13:58:34 +0800 Subject: [PATCH 14/21] Remove unused code --- .../flash_attention_fwd_impl.hpp | 8 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 1 - .../flash_attention_fwd.hpp | 178 ------------------ .../flash_attention_fwd_impl.hpp | 8 +- example/ck_tile/99_toy_example/README.md | 2 +- 5 files changed, 7 insertions(+), 190 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index c10689567f..4aae01d1c8 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -4,17 +4,15 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -#include "tile_gemm_shape.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "ck_tile/ops/reduce.hpp" #include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck_tile/ops/reduce.hpp" +#include "tile_gemm_shape.hpp" namespace ck_tile { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 4dd0c9a1e0..5642617856 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -4,7 +4,6 @@ #pragma once #include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" namespace ck_tile { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index 324d7924eb..c79e6f6094 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -145,7 +145,6 @@ struct FlashAttentionFwd const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); -#if defined(GEMM_OPT) const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; @@ -157,20 +156,6 @@ struct FlashAttentionFwd const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); -#else - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - - return make_tuple(quotient, modulus); - }; - const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); - const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); - -#endif const auto kernel_impl = FlashAttentionFwdImpl -// struct flash_attention_fwd_traits_ -// { -// using SaccDataType = ck_tile::remove_cvref_t; -// using SMPLComputeDataType = ck_tile::remove_cvref_t; -// using PDataType = ck_tile::remove_cvref_t; -// using OaccDataType = ck_tile::remove_cvref_t; - -// static constexpr index_t kBlockSize = kBlockSize_; -// static constexpr index_t kHeadDim = kHeadDim_; -// static constexpr index_t kM0PerBlock = kM0PerBlock_; -// static constexpr index_t kN0PerBlock = kN0PerBlock_; -// static constexpr index_t kK0PerBlock = kK0PerBlock_; -// static constexpr index_t kN1PerBlock = kN1PerBlock_; -// static constexpr index_t kK1PerBlock = kK1PerBlock_; - -// static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD -// static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; -// static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -// }; - -// // TODO: fwd_api.cpp, fwd_common.cpp -// template -// using traits_ = flash_attention_fwd_traits_; -// // fw_api.cpp -// // Note: this internal API only declare, not define here, otherwise will block `make -j` -// template -// float flash_attention_fwd_(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config); - -// // TODO: fwd_common.cpp -// template -// float flash_attention_fwd_(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config) { -// using SaccDataType = typename Traits_::SaccDataType; -// using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -// using PDataType = typename Traits_::PDataType; -// using OaccDataType = typename Traits_::OaccDataType; - -// index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - -// std::cout << "grid size " << kGridSize << std::endl; - -// return ck_tile::launch_kernel(stream_config, -// ck_tile::make_kernel( -// ck_tile::FlashAttentionFwd{}, -// kGridSize, -// Traits_::kBlockSize, -// 0, -// a.q_ptr, -// a.k_ptr, -// a.v_ptr, -// a.o_ptr, -// a.M0, -// a.N0, -// a.K0, -// a.N1, -// a.Batch, -// a.strideQ, // StrideQ -// a.strideK, // StrideK -// a.strideV, // StrideV -// a.strideO, // StrideO -// a.batchStrideQ, // BatchStrideQ -// a.batchStrideK, // BatchStrideK -// a.batchStrideV, // BatchStrideV -// a.batchStrideO)); // BatchStrideO -// } - -// // TODO: change to only declare -// // TODO: fwd_api.cpp -// template -// float flash_attention_fwd(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config) { -// constexpr ck_tile::index_t kM0PerBlock = 128; -// constexpr ck_tile::index_t kN0PerBlock = 128; -// constexpr ck_tile::index_t kK0PerBlock = 32; -// constexpr ck_tile::index_t kN1PerBlock = 128; -// constexpr ck_tile::index_t kK1PerBlock = 32; - -// constexpr ck_tile::index_t kBlockSize = 256; -// constexpr ck_tile::index_t kHeadDim = 128; - -// return flash_attention_fwd_> -// (a, stream_config); - -// } - -// TODO: change to only declare -// TODO: fwd_api.cpp template Date: Wed, 23 Apr 2025 14:41:09 +0800 Subject: [PATCH 15/21] Add more warp gemm policies for FA --- .../block_gemm_areg_bsmem_creg_v1.hpp | 8 ++++++-- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 19 +++++++++++++++++-- ...emm_areg_bsmem_creg_v1_iteratek_policy.hpp | 19 +++++++++++++++++-- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 4 +++- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 76548985ab..23ca02946b 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -130,6 +130,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -138,7 +140,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -272,6 +274,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -280,7 +284,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index fb1516eb52..6eee7f0d1e 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1DefaultPolicy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 32dc09f95e..3924a66daf 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1K8Policy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index c9e1e30c57..5642617856 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -61,11 +61,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = remove_cvref_t; + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = AKDim; constexpr auto config = - BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; From 257a06ef54d1483573fae4b18b971ad0ae65aaae Mon Sep 17 00:00:00 2001 From: Clement Lin Date: Wed, 23 Apr 2025 15:04:24 +0800 Subject: [PATCH 16/21] Add the warp gemm option --- .../03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 23ca02946b..fc18640958 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -36,7 +36,7 @@ struct BlockGemmARegBSmemCRegV1 constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr index_t KPerBlock = BlockGemmShape::kK; - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MPerXDL = WG::kM; constexpr index_t NPerXDL = WG::kN; From c4b2d5074a9ce3debdfc5750f4a3a3ff72f8e88c Mon Sep 17 00:00:00 2001 From: MHYang Date: Wed, 23 Apr 2025 23:06:23 +0000 Subject: [PATCH 17/21] Implement prefetch and instruction schedule --- .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 23 +- .../03_flash_attention_fwd/CMakeLists.txt | 7 + .../block_gemm_areg_bsmem_creg_v1.hpp | 152 +++++++++++ ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 80 ++++-- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 120 ++++++--- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 206 --------------- .../flash_attention_fwd.hpp | 5 +- .../flash_attention_fwd_impl.hpp | 107 ++++++-- .../block_gemm_areg_bsmem_creg_v1.hpp | 244 ++++++++++++++++++ ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 85 ++++-- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 120 ++++++--- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 206 --------------- .../flash_attention_fwd.hpp | 1 - .../flash_attention_fwd_impl.hpp | 111 ++++++-- 14 files changed, 865 insertions(+), 602 deletions(-) delete mode 100644 example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp delete mode 100644 example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index effcc2b101..1aa11791dd 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -306,11 +306,11 @@ struct BlockGemmPipelineAGmemBGmemCReg // ------------------------------------------------------------------------------------- // Gemm pipeline start -#if defined(ENABLE_PREFETCH) -#pragma message("global prefetch") // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if defined(ENABLE_PREFETCH) +#pragma message("global prefetch") // Prefetch // Global read 0 a_block_tile = load_tile(a_copy_dram_window); @@ -325,7 +325,7 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); - // Global read 0 + // Global read 1 a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); @@ -347,11 +347,11 @@ struct BlockGemmPipelineAGmemBGmemCReg { block_sync_lds(); - // LDS write 0 + // LDS write 1 store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); - // Global read 0 + // Global read 2 a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); @@ -387,18 +387,7 @@ struct BlockGemmPipelineAGmemBGmemCReg block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); #else // non-prefetch - a_block_tile = load_tile(a_copy_dram_window); - b_block_tile = load_tile(b_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - store_tile(a_copy_lds_window, a_block_tile); - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - - index_t iCounter = num_loop - 1; + index_t iCounter = num_loop; while(iCounter > 0) { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt index 44dfac099c..4c71936c61 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -10,6 +10,13 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF) +if(ENABLE_TOY_FA_FWD_OPT) + message("Compiling with toy FA fwd optimization") + # target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) + add_definitions(-DTOY_FA_FWD_OPT) +endif() + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index fc18640958..e6b1707851 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -29,6 +29,35 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kPackedSize = ck_tile::numeric_traits>::PackedSize; + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + template CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { @@ -118,6 +147,129 @@ struct BlockGemmARegBSmemCRegV1 }); } + // C += A * B + template + __device__ void operator() (CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 02fe7f54ad..4383ce76cb 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -222,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -261,62 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } - __builtin_amdgcn_sched_barrier(0); if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); block_gemm.HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -324,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 5642617856..68afedda27 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,42 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } + + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index 4d48478084..0000000000 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index c880639621..7fc5a78806 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -8,11 +8,11 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" + namespace ck_tile { CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) @@ -115,7 +115,8 @@ struct FlashAttentionFwd const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); -#if defined(GEMM_OPT) +#if defined(TOY_FA_FWD_OPT) +#pragma message("Enable toy FA fwd opt") const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4aae01d1c8..689b557500 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,6 +14,7 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" + namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - // Q in register auto q_reg_tensor = load_tile(q_dram_window); @@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -304,34 +339,60 @@ struct FlashAttentionFwdImpl constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; if constexpr(k1_loops > 1) + { + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) { __builtin_amdgcn_sched_barrier(0); - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, v); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock}); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr (k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 33d36954c0..e6b1707851 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -26,6 +26,250 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator() (CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index ec9484ffd1..4383ce76cb 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -134,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -158,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -217,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -256,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -315,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 5642617856..68afedda27 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,42 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } + + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index 4d48478084..0000000000 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index c79e6f6094..dcb901c0a2 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -9,7 +9,6 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 0b396ea59f..689b557500 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,6 +14,7 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" + namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - // Q in register auto q_reg_tensor = load_tile(q_dram_window); @@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -305,29 +340,59 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, v); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr (k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; From 0e6a23258e85afa09cd75788c31645f97fb045d2 Mon Sep 17 00:00:00 2001 From: MHYang Date: Thu, 24 Apr 2025 09:09:18 +0000 Subject: [PATCH 18/21] Fix clang-format --- .../block_gemm_areg_bsmem_creg_v1.hpp | 17 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 6 +- ...emm_areg_bsmem_creg_v1_iteratek_policy.hpp | 6 +- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 25 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 15 +- .../flash_attention_fwd.cpp | 14 +- .../flash_attention_fwd.hpp | 1 - .../flash_attention_fwd_impl.hpp | 49 +- .../block_gemm_areg_bsmem_creg_v1.hpp | 17 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 6 +- ...emm_areg_bsmem_creg_v1_iteratek_policy.hpp | 6 +- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 25 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 15 +- .../flash_attention_fwd.cpp | 14 +- .../flash_attention_fwd.hpp | 1 - .../flash_attention_fwd_impl.hpp | 49 +- .../generate.py | 539 +++++++++--------- 17 files changed, 408 insertions(+), 397 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index e6b1707851..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1 // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template get<1>(); @@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1 return b_block_dstr_encode; } - static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); template @@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1 // C += A * B template - __device__ void operator() (CBlockTensor& c_block_tensor, + __device__ void operator()(CBlockTensor& c_block_tensor, const ABlockTensorTmp& a_block_tensor_tmp, const BLdsTile& b_block_tensor_tmp) const { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 6eee7f0d1e..8994689841 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 3924a66daf..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4383ce76cb..928ca83f65 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile @@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); #if !defined(TOY_FA_FWD_OPT) @@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); block_sync_lds(); }); #else @@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg move_tile_window(b_copy_dram_window, {0, kKPerBlock}); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - bWarpTile); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); @@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); block_sync_lds(); } @@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); } #endif return c_block_tile; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 68afedda27..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return BlockGemmARegBSmemCRegV1{}; } - template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { @@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } - template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } - template __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() { @@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple( - make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } - template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index c797714421..4ce61ed20c 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; if(argc == 3) { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index 7fc5a78806..4317ebee8d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -12,7 +12,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" - namespace ck_tile { CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 689b557500..bffed23722 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,7 +14,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" - namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl // V LDS tile window for store auto v_copy_lds_window = make_tile_window(v_lds, - make_tuple(number{}, number{}), - {0, 0}, - v_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); // V LDS tile for block GEMM - auto v_lds_gemm_window = make_tile_window( - v_lds, make_tuple(number{}, number{}), {0, 0}, - make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); #else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl store_tile(v_lds_window, v); block_sync_lds(); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); block_sync_lds(); }); #else @@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl move_tile_window(v_dram_window, {0, kK1PerBlock}); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); @@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl } // tail { - if constexpr (k1_loops > 1) + if constexpr(k1_loops > 1) { gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 2) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } #endif diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index e6b1707851..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1 // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template get<1>(); @@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1 return b_block_dstr_encode; } - static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); template @@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1 // C += A * B template - __device__ void operator() (CBlockTensor& c_block_tensor, + __device__ void operator()(CBlockTensor& c_block_tensor, const ABlockTensorTmp& a_block_tensor_tmp, const BLdsTile& b_block_tensor_tmp) const { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 6eee7f0d1e..8994689841 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 3924a66daf..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4383ce76cb..928ca83f65 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile @@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); #if !defined(TOY_FA_FWD_OPT) @@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); block_sync_lds(); }); #else @@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg move_tile_window(b_copy_dram_window, {0, kKPerBlock}); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - bWarpTile); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); @@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); block_sync_lds(); } @@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); } #endif return c_block_tile; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 68afedda27..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return BlockGemmARegBSmemCRegV1{}; } - template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { @@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } - template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } - template __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() { @@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple( - make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } - template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 48680a218a..3750ede188 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; if(argc == 3) { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index dcb901c0a2..38c56a27e8 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -155,7 +155,6 @@ struct FlashAttentionFwd const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); - const auto kernel_impl = FlashAttentionFwdImpl{}, number{}), - {0, 0}, - v_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); // V LDS tile for block GEMM - auto v_lds_gemm_window = make_tile_window( - v_lds, make_tuple(number{}, number{}), {0, 0}, - make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); #else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl store_tile(v_lds_window, v); block_sync_lds(); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); block_sync_lds(); }); #else @@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl move_tile_window(v_dram_window, {0, kK1PerBlock}); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); @@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl } // tail { - if constexpr (k1_loops > 1) + if constexpr(k1_loops > 1) { gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 2) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } #endif diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 00bc91cadc..10def9a5dd 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -1,33 +1,18 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#SPDX - License - Identifier : MIT +#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved. -import argparse -from enum import IntEnum -from pathlib import Path -import sys -from typing import List, Optional, Any -import functools -import itertools -import copy -from dataclasses import dataclass +import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass -def get_if_str(size_, total, last_else=True): - if size_ == "head_dim_256_seq_4096": - return 'if' - else: - return 'else if' + def get_if_str(size_, total, last_else = True) : if size_ == "head_dim_256_seq_4096" : return 'if' else : return 'else if' -DATA_TYPE_MAP = {'fp32': 'float', - 'fp16': 'ck_tile::half_t', - 'bf16': 'ck_tile::bf16_t'} + DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::half_t', 'bf16' : 'ck_tile::bf16_t' } -def BOOL_MAP(b_) -> str: - return 'true' if b_ else 'false' + def BOOL_MAP(b_)->str: return 'true' if b_ else 'false' -class FlashAttentionFwdCodegen: - API_TRAITS_DEFINE = """ + class FlashAttentionFwdCodegen:API_TRAITS_DEFINE = "" + " -template -struct flash_attention_fwd_traits_ -{ - using SaccDataType = ck_tile::remove_cvref_t; - using SMPLComputeDataType = ck_tile::remove_cvref_t; - using PDataType = ck_tile::remove_cvref_t; - using OaccDataType = ck_tile::remove_cvref_t; + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_{using SaccDataType = ck_tile::remove_cvref_t ; +using SMPLComputeDataType = ck_tile::remove_cvref_t; +using PDataType = ck_tile::remove_cvref_t; +using OaccDataType = ck_tile::remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kHeadDim = kHeadDim_; - static constexpr index_t kM0PerBlock = kM0PerBlock_; - static constexpr index_t kN0PerBlock = kN0PerBlock_; - static constexpr index_t kK0PerBlock = kK0PerBlock_; - static constexpr index_t kN1PerBlock = kN1PerBlock_; - static constexpr index_t kK1PerBlock = kK1PerBlock_; +static constexpr index_t kBlockSize = kBlockSize_; +static constexpr index_t kHeadDim = kHeadDim_; +static constexpr index_t kM0PerBlock = kM0PerBlock_; +static constexpr index_t kN0PerBlock = kN0PerBlock_; +static constexpr index_t kK0PerBlock = kK0PerBlock_; +static constexpr index_t kN1PerBlock = kN1PerBlock_; +static constexpr index_t kK1PerBlock = kK1PerBlock_; - static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); - static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}; +static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD +static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); +static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +} +; template using traits_ = flash_attention_fwd_traits_; -""" + SMPLComputeDataType, + PDataType, + OaccDataType, + kBlockSize, + kHeadDim, + kM0PerBlock, + kN0PerBlock, + kK0PerBlock, + kN1PerBlock, + kK1PerBlock>; +"" + " - API_BASE = """ + API_BASE = "" + " // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "flash_attention_fwd.hpp" -namespace ck_tile {{ + namespace ck_tile +{ + { -{F_traits_define} + { + F_traits_define + } -// Note: this internal API only declare, not define here, otherwise will block `make -j` -template -float flash_attention_fwd_(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config); + // Note: this internal API only declare, not define here, otherwise will block `make -j` + template + float flash_attention_fwd_( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config); -template -float flash_attention_fwd(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) {{ - float r = -1; -{F_dispatch} - return r; -}} - -template float flash_attention_fwd( - const FlashAttnArgs&, - const ck_tile::stream_config&); - -}} -""" - - API_INNER_CASE = """ {F_if} {F_VEC_COND} - r = flash_attention_fwd_>(a, stream_config); -""" - - INSTANCE_BASE = """ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "flash_attention_fwd_api_common.hpp" - -namespace ck_tile { -// clang-format off -// -{F_instance_def} -// clang-format on + template + float flash_attention_fwd( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) + { + { + float r = -1; + { + F_dispatch + } + return r; + } + } + template float flash_attention_fwd(const FlashAttnArgs&, + const ck_tile::stream_config&); + } } -""" +"" + " - def __init__(self, working_path, kernel_filter): - self.working_path = working_path - self.kernel_filter = kernel_filter + API_INNER_CASE = "" + " {F_if} {F_VEC_COND} + r = flash_attention_fwd_>( + a, stream_config); +"" + " - @dataclass - class h_traits: - F_SaccDataType: str - F_SMPLComputeDataType: str - F_PDataType: str - F_OaccDataType: str - F_kBlockSize: int - F_kHeadDim: int - F_kM0PerBlock: int - F_kN0PerBlock: int - F_kK0PerBlock: int - F_kN1PerBlock: int - F_kK1PerBlock: int - - @property - def trait_name(self) -> str: - return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, " - f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " - f"{DATA_TYPE_MAP[self.F_PDataType]}, " - f"{DATA_TYPE_MAP[self.F_OaccDataType]}, " - f"{self.F_kBlockSize}, {self.F_kHeadDim}, " - f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " - f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") - - @property - def def_name(self) -> str: - return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " - f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " - f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " - f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " - "const ck_tile::stream_config&);") - - @dataclass - class h_instance: - F_DataTypePair: str # "q,k,v,o" - F_SizeCategory: str # "small", "medium", "large" - instance_list: List[Any] # List[h_traits] - - INSTANCE_BASE = """ + INSTANCE_BASE = "" + " // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "flash_attention_fwd_api_common.hpp" -namespace ck_tile {{ -// clang-format off + namespace ck_tile +{ + // clang-format off // {F_instance_def} -// clang-format on -}} + // clang-format on +} +"" + " + + def + __init__(self, working_path, kernel_filter) + : self.working_path = working_path self.kernel_filter = kernel_filter + + @dataclass class h_traits + : F_SaccDataType : str F_SMPLComputeDataType : str F_PDataType : str F_OaccDataType + : str F_kBlockSize : int F_kHeadDim : int F_kM0PerBlock : int F_kN0PerBlock : int F_kK0PerBlock + : int F_kN1PerBlock : int F_kK1PerBlock : int + + @property def trait_name(self) + ->str + : return (f "{DATA_TYPE_MAP[self.F_SaccDataType]}, " f + "{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " f + "{DATA_TYPE_MAP[self.F_PDataType]}, " f "{DATA_TYPE_MAP[self.F_OaccDataType]}, " f + "{self.F_kBlockSize}, {self.F_kHeadDim}, " f + "{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " f + "{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") + + @property def def_name(self) + ->str + : return (f "template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " f + "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " f + "traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " f + "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " + "const ck_tile::stream_config&);") + + @dataclass class h_instance : F_DataTypePair : str #"q,k,v,o" F_SizeCategory : str + #"small", + "medium", + "large" instance_list : List[Any] #List[h_traits] + + INSTANCE_BASE = "" + " +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "flash_attention_fwd_api_common.hpp" + + namespace ck_tile +{ + { + // clang-format off +// +{F_instance_def} + // clang-format on + }} """ @property @@ -226,123 +229,133 @@ namespace ck_tile {{ #include "flash_attention_fwd.hpp" -namespace ck_tile {{ +namespace ck_tile +{ + { -template -struct flash_attention_fwd_traits_ -{{ - using SaccDataType = ck_tile::remove_cvref_t; - using SMPLComputeDataType = ck_tile::remove_cvref_t; - using PDataType = ck_tile::remove_cvref_t; - using OaccDataType = ck_tile::remove_cvref_t; + template + struct flash_attention_fwd_traits_ + { + { + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kHeadDim = kHeadDim_; - static constexpr index_t kM0PerBlock = kM0PerBlock_; - static constexpr index_t kN0PerBlock = kN0PerBlock_; - static constexpr index_t kK0PerBlock = kK0PerBlock_; - static constexpr index_t kN1PerBlock = kN1PerBlock_; - static constexpr index_t kK1PerBlock = kK1PerBlock_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; - static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; - static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}}; + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + } + }; + template + using traits_ = flash_attention_fwd_traits_; -template -using traits_ = flash_attention_fwd_traits_; + template + float flash_attention_fwd_( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) + { + { + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; + index_t kGridSize = + a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); -template -float flash_attention_fwd_(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) {{ - using SaccDataType = typename Traits_::SaccDataType; - using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; - using PDataType = typename Traits_::PDataType; - using OaccDataType = typename Traits_::OaccDataType; + if(stream_config.log_level_ > 0) + std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," + << Traits_::kHeadDim << ">" << std::flush; - index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - - if(stream_config.log_level_ > 0) - std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; - - return ck_tile::launch_kernel(stream_config, - ck_tile::make_kernel( - ck_tile::FlashAttentionFwd{{}}, - kGridSize, - Traits_::kBlockSize, - 0, - a.q_ptr, - a.k_ptr, - a.v_ptr, - a.o_ptr, - a.M0, - a.N0, - a.K0, - a.N1, - a.Batch, - a.strideQ, // StrideQ - a.strideK, // StrideK - a.strideV, // StrideV - a.strideO, // StrideO - a.batchStrideQ, // BatchStrideQ - a.batchStrideK, // BatchStrideK - a.batchStrideV, // BatchStrideV - a.batchStrideO)); // BatchStrideO -}} -}} + return ck_tile::launch_kernel( + stream_config, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{{}}, + kGridSize, + Traits_::kBlockSize, + 0, + a.q_ptr, + a.k_ptr, + a.v_ptr, + a.o_ptr, + a.M0, + a.N0, + a.K0, + a.N1, + a.Batch, + a.strideQ, // StrideQ + a.strideK, // StrideK + a.strideV, // StrideV + a.strideO, // StrideO + a.batchStrideQ, // BatchStrideQ + a.batchStrideK, // BatchStrideK + a.batchStrideV, // BatchStrideV + a.batchStrideO)); // BatchStrideO + } + } + } +} """ def content_api(self, args) -> str: - # Sort based on dtype +#Sort based on dtype t_dtype_dict = {} blobs = self.get_blobs(args) @@ -402,7 +415,7 @@ float flash_attention_fwd_(const FlashAttnArgs Date: Thu, 24 Apr 2025 09:50:22 +0000 Subject: [PATCH 19/21] Fix generate.py --- .../generate.py | 513 +++++++++--------- 1 file changed, 250 insertions(+), 263 deletions(-) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 10def9a5dd..00bc91cadc 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -1,18 +1,33 @@ -#SPDX - License - Identifier : MIT -#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved. +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass - def get_if_str(size_, total, last_else = True) : if size_ == "head_dim_256_seq_4096" : return 'if' else : return 'else if' +def get_if_str(size_, total, last_else=True): + if size_ == "head_dim_256_seq_4096": + return 'if' + else: + return 'else if' - DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::half_t', 'bf16' : 'ck_tile::bf16_t' } +DATA_TYPE_MAP = {'fp32': 'float', + 'fp16': 'ck_tile::half_t', + 'bf16': 'ck_tile::bf16_t'} - def BOOL_MAP(b_)->str: return 'true' if b_ else 'false' +def BOOL_MAP(b_) -> str: + return 'true' if b_ else 'false' - class FlashAttentionFwdCodegen:API_TRAITS_DEFINE = "" - " +class FlashAttentionFwdCodegen: + API_TRAITS_DEFINE = """ - template struct flash_attention_fwd_traits_{using SaccDataType = ck_tile::remove_cvref_t ; -using SMPLComputeDataType = ck_tile::remove_cvref_t; -using PDataType = ck_tile::remove_cvref_t; -using OaccDataType = ck_tile::remove_cvref_t; + index_t kK1PerBlock_ = 64> +struct flash_attention_fwd_traits_ +{ + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; -static constexpr index_t kBlockSize = kBlockSize_; -static constexpr index_t kHeadDim = kHeadDim_; -static constexpr index_t kM0PerBlock = kM0PerBlock_; -static constexpr index_t kN0PerBlock = kN0PerBlock_; -static constexpr index_t kK0PerBlock = kK0PerBlock_; -static constexpr index_t kN1PerBlock = kN1PerBlock_; -static constexpr index_t kK1PerBlock = kK1PerBlock_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; -static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD -static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); -static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -} -; + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +}; template using traits_ = flash_attention_fwd_traits_; -"" - " + SMPLComputeDataType, + PDataType, + OaccDataType, + kBlockSize, + kHeadDim, + kM0PerBlock, + kN0PerBlock, + kK0PerBlock, + kN1PerBlock, + kK1PerBlock>; +""" - API_BASE = "" - " + API_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "flash_attention_fwd.hpp" - namespace ck_tile -{ - { +namespace ck_tile {{ - { - F_traits_define - } +{F_traits_define} - // Note: this internal API only declare, not define here, otherwise will block `make -j` - template - float flash_attention_fwd_( - const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config); +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float flash_attention_fwd_(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config); - template - float flash_attention_fwd( - const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) - { - { - float r = -1; - { - F_dispatch - } - return r; - } - } +template +float flash_attention_fwd(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) {{ + float r = -1; +{F_dispatch} + return r; +}} - template float flash_attention_fwd(const FlashAttnArgs&, - const ck_tile::stream_config&); - } -} -"" - " +template float flash_attention_fwd( + const FlashAttnArgs&, + const ck_tile::stream_config&); - API_INNER_CASE = "" - " {F_if} {F_VEC_COND} - r = flash_attention_fwd_>( - a, stream_config); -"" - " +}} +""" - INSTANCE_BASE = "" - " + API_INNER_CASE = """ {F_if} {F_VEC_COND} + r = flash_attention_fwd_>(a, stream_config); +""" + + INSTANCE_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "flash_attention_fwd_api_common.hpp" - namespace ck_tile -{ - // clang-format off +namespace ck_tile { +// clang-format off // {F_instance_def} - // clang-format on +// clang-format on + } -"" - " +""" - def - __init__(self, working_path, kernel_filter) - : self.working_path = working_path self.kernel_filter = kernel_filter + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter - @dataclass class h_traits - : F_SaccDataType : str F_SMPLComputeDataType : str F_PDataType : str F_OaccDataType - : str F_kBlockSize : int F_kHeadDim : int F_kM0PerBlock : int F_kN0PerBlock : int F_kK0PerBlock - : int F_kN1PerBlock : int F_kK1PerBlock : int + @dataclass + class h_traits: + F_SaccDataType: str + F_SMPLComputeDataType: str + F_PDataType: str + F_OaccDataType: str + F_kBlockSize: int + F_kHeadDim: int + F_kM0PerBlock: int + F_kN0PerBlock: int + F_kK0PerBlock: int + F_kN1PerBlock: int + F_kK1PerBlock: int - @property def trait_name(self) - ->str - : return (f "{DATA_TYPE_MAP[self.F_SaccDataType]}, " f - "{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " f - "{DATA_TYPE_MAP[self.F_PDataType]}, " f "{DATA_TYPE_MAP[self.F_OaccDataType]}, " f - "{self.F_kBlockSize}, {self.F_kHeadDim}, " f - "{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " f - "{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") + @property + def trait_name(self) -> str: + return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, " + f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " + f"{DATA_TYPE_MAP[self.F_PDataType]}, " + f"{DATA_TYPE_MAP[self.F_OaccDataType]}, " + f"{self.F_kBlockSize}, {self.F_kHeadDim}, " + f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " + f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") - @property def def_name(self) - ->str - : return (f "template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " f - "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " f - "traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " f - "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " - "const ck_tile::stream_config&);") + @property + def def_name(self) -> str: + return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " + f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " + f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " + f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " + "const ck_tile::stream_config&);") - @dataclass class h_instance : F_DataTypePair : str #"q,k,v,o" F_SizeCategory : str - #"small", - "medium", - "large" instance_list : List[Any] #List[h_traits] + @dataclass + class h_instance: + F_DataTypePair: str # "q,k,v,o" + F_SizeCategory: str # "small", "medium", "large" + instance_list: List[Any] # List[h_traits] - INSTANCE_BASE = "" - " + INSTANCE_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "flash_attention_fwd_api_common.hpp" - namespace ck_tile -{ - { - // clang-format off +namespace ck_tile {{ +// clang-format off // {F_instance_def} - // clang-format on - }} +// clang-format on +}} """ @property @@ -229,133 +226,123 @@ using traits_ = flash_attention_fwd_traits_ - struct flash_attention_fwd_traits_ - { - { - using SaccDataType = ck_tile::remove_cvref_t; - using SMPLComputeDataType = ck_tile::remove_cvref_t; - using PDataType = ck_tile::remove_cvref_t; - using OaccDataType = ck_tile::remove_cvref_t; +template +struct flash_attention_fwd_traits_ +{{ + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kHeadDim = kHeadDim_; - static constexpr index_t kM0PerBlock = kM0PerBlock_; - static constexpr index_t kN0PerBlock = kN0PerBlock_; - static constexpr index_t kK0PerBlock = kK0PerBlock_; - static constexpr index_t kN1PerBlock = kN1PerBlock_; - static constexpr index_t kK1PerBlock = kK1PerBlock_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; - static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; - static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; - } - }; + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +}}; - template - using traits_ = flash_attention_fwd_traits_; - template - float flash_attention_fwd_( - const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) - { - { - using SaccDataType = typename Traits_::SaccDataType; - using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; - using PDataType = typename Traits_::PDataType; - using OaccDataType = typename Traits_::OaccDataType; +template +using traits_ = flash_attention_fwd_traits_; - index_t kGridSize = - a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - if(stream_config.log_level_ > 0) - std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," - << Traits_::kHeadDim << ">" << std::flush; +template +float flash_attention_fwd_(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) {{ + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; - return ck_tile::launch_kernel( - stream_config, - ck_tile::make_kernel( - ck_tile::FlashAttentionFwd{{}}, - kGridSize, - Traits_::kBlockSize, - 0, - a.q_ptr, - a.k_ptr, - a.v_ptr, - a.o_ptr, - a.M0, - a.N0, - a.K0, - a.N1, - a.Batch, - a.strideQ, // StrideQ - a.strideK, // StrideK - a.strideV, // StrideV - a.strideO, // StrideO - a.batchStrideQ, // BatchStrideQ - a.batchStrideK, // BatchStrideK - a.batchStrideV, // BatchStrideV - a.batchStrideO)); // BatchStrideO - } - } - } -} + index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); + + if(stream_config.log_level_ > 0) + std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; + + return ck_tile::launch_kernel(stream_config, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{{}}, + kGridSize, + Traits_::kBlockSize, + 0, + a.q_ptr, + a.k_ptr, + a.v_ptr, + a.o_ptr, + a.M0, + a.N0, + a.K0, + a.N1, + a.Batch, + a.strideQ, // StrideQ + a.strideK, // StrideK + a.strideV, // StrideV + a.strideO, // StrideO + a.batchStrideQ, // BatchStrideQ + a.batchStrideK, // BatchStrideK + a.batchStrideV, // BatchStrideV + a.batchStrideO)); // BatchStrideO +}} +}} """ def content_api(self, args) -> str: -#Sort based on dtype + # Sort based on dtype t_dtype_dict = {} blobs = self.get_blobs(args) @@ -415,7 +402,7 @@ namespace ck_tile h_traits = self.h_traits h_instance = self.h_instance -#Define kernel configurations for different size categories + # Define kernel configurations for different size categories trait_dict = { "head_dim_256_seq_4096": [ h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64), @@ -437,17 +424,17 @@ namespace ck_tile ], } -#Toy example only support fp16 + # Toy example only support fp16 dtype_combinations = [ "fp16,fp16,fp16,fp16" -#"bf16,bf16,bf16,bf16" + # "bf16,bf16,bf16,bf16" ] total_blob = [] for dtype_pair in dtype_combinations: for size_category in trait_dict: traits = trait_dict[size_category] -#Convert data types for the current dtype_pair + # Convert data types for the current dtype_pair q_type, k_type, v_type, o_type = dtype_pair.split(',') current_traits = [] for t in traits: @@ -468,10 +455,10 @@ namespace ck_tile blobs = self.get_blobs(args) with list_p.open('w') as list_f: -#API related files + # API related files list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") -#Kernel instance files + # Kernel instance files for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") From 7e0ca8c2c7d806febc1d8672d40a064cf05c3c72 Mon Sep 17 00:00:00 2001 From: MHYang Date: Thu, 24 Apr 2025 10:41:34 +0000 Subject: [PATCH 20/21] Remove unused flag --- .../99_toy_example/03_flash_attention_fwd/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt index 4c71936c61..23fc7484dd 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -13,8 +13,7 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF) if(ENABLE_TOY_FA_FWD_OPT) message("Compiling with toy FA fwd optimization") - # target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) - add_definitions(-DTOY_FA_FWD_OPT) + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) endif() target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) From e0bfe71854917c0e8afacde4bd9a203e50d91191 Mon Sep 17 00:00:00 2001 From: MHYang Date: Thu, 24 Apr 2025 16:50:22 +0000 Subject: [PATCH 21/21] Fix unexpected errors --- .../block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 3 +-- .../03_flash_attention_fwd/flash_attention_fwd_impl.hpp | 4 ++-- .../block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 3 +-- .../flash_attention_fwd_impl.hpp | 4 ++-- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 928ca83f65..cfbd7d6376 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -292,10 +292,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // Global read 0 auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); if constexpr(k_loops > 1) { - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index bffed23722..fbca3a95ac 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -171,7 +171,7 @@ struct FlashAttentionFwdImpl auto q_dram_window = make_tile_window( q_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {iM0, 0}, BlockGemm0Policy::template MakeADramTileDistribution()); @@ -257,6 +257,7 @@ struct FlashAttentionFwdImpl #if defined(TOY_FA_FWD_OPT) // prefetch load v tile auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); #endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( @@ -341,7 +342,6 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - move_tile_window(v_dram_window, {0, kK1PerBlock}); store_tile(v_copy_lds_window, v_prefetch); v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock}); diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 928ca83f65..cfbd7d6376 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -292,10 +292,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // Global read 0 auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); if constexpr(k_loops > 1) { - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index bffed23722..fbca3a95ac 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -171,7 +171,7 @@ struct FlashAttentionFwdImpl auto q_dram_window = make_tile_window( q_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {iM0, 0}, BlockGemm0Policy::template MakeADramTileDistribution()); @@ -257,6 +257,7 @@ struct FlashAttentionFwdImpl #if defined(TOY_FA_FWD_OPT) // prefetch load v tile auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); #endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( @@ -341,7 +342,6 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - move_tile_window(v_dram_window, {0, kK1PerBlock}); store_tile(v_copy_lds_window, v_prefetch); v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock});