diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index effcc2b101..1aa11791dd 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -306,11 +306,11 @@ struct BlockGemmPipelineAGmemBGmemCReg // ------------------------------------------------------------------------------------- // Gemm pipeline start -#if defined(ENABLE_PREFETCH) -#pragma message("global prefetch") // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if defined(ENABLE_PREFETCH) +#pragma message("global prefetch") // Prefetch // Global read 0 a_block_tile = load_tile(a_copy_dram_window); @@ -325,7 +325,7 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); - // Global read 0 + // Global read 1 a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); @@ -347,11 +347,11 @@ struct BlockGemmPipelineAGmemBGmemCReg { block_sync_lds(); - // LDS write 0 + // LDS write 1 store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); - // Global read 0 + // Global read 2 a_block_tile = load_tile(a_copy_dram_window); b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); @@ -387,18 +387,7 @@ struct BlockGemmPipelineAGmemBGmemCReg block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); #else // non-prefetch - a_block_tile = load_tile(a_copy_dram_window); - b_block_tile = load_tile(b_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - store_tile(a_copy_lds_window, a_block_tile); - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - - index_t iCounter = num_loop - 1; + index_t iCounter = num_loop; while(iCounter > 0) { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt index 44dfac099c..4c71936c61 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -10,6 +10,13 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF) +if(ENABLE_TOY_FA_FWD_OPT) + message("Compiling with toy FA fwd optimization") + # target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) + add_definitions(-DTOY_FA_FWD_OPT) +endif() + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index fc18640958..e6b1707851 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -29,6 +29,35 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kPackedSize = ck_tile::numeric_traits>::PackedSize; + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + template CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { @@ -118,6 +147,129 @@ struct BlockGemmARegBSmemCRegV1 }); } + // C += A * B + template + __device__ void operator() (CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 02fe7f54ad..4383ce76cb 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -222,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -261,62 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } - __builtin_amdgcn_sched_barrier(0); if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); block_gemm.HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -324,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 5642617856..68afedda27 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,42 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } + + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index 4d48478084..0000000000 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index c880639621..7fc5a78806 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -8,11 +8,11 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" + namespace ck_tile { CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) @@ -115,7 +115,8 @@ struct FlashAttentionFwd const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); -#if defined(GEMM_OPT) +#if defined(TOY_FA_FWD_OPT) +#pragma message("Enable toy FA fwd opt") const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4aae01d1c8..689b557500 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,6 +14,7 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" + namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - // Q in register auto q_reg_tensor = load_tile(q_dram_window); @@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -304,34 +339,60 @@ struct FlashAttentionFwdImpl constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; if constexpr(k1_loops > 1) + { + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) { __builtin_amdgcn_sched_barrier(0); - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, v); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock}); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr (k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 33d36954c0..e6b1707851 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -26,6 +26,250 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator() (CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index ec9484ffd1..4383ce76cb 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -134,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -158,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -217,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -256,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); if constexpr(k_loops > 1) { move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS write 0 store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -315,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 5642617856..68afedda27 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,42 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -47,13 +20,7 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -93,11 +60,88 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } + + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index 4d48478084..0000000000 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = 8; - - using BDataType = remove_cvref_t; - - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index c79e6f6094..dcb901c0a2 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -9,7 +9,6 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 0b396ea59f..689b557500 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,6 +14,7 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" + namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -152,6 +153,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -179,7 +184,6 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - // Q in register auto q_reg_tensor = load_tile(q_dram_window); @@ -188,12 +192,22 @@ struct FlashAttentionFwdImpl auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = make_tile_window( + v_lds, make_tuple(number{}, number{}), {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -239,9 +253,10 @@ struct FlashAttentionFwdImpl const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -291,10 +306,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -305,29 +340,59 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, v); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); move_tile_window(v_dram_window, {0, kK1PerBlock}); + + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr (k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock;