diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp index e77f3c835e..5c58fa3d60 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -21,7 +21,11 @@ struct BlockGemmASmemBSmemCReg using BlockGemmShape = remove_cvref_t; using WarpGemm = remove_cvref_t< - decltype(Policy::template GetWarpGemmMWarpNWarp().template at<0>())>; + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); using AWarpDstr = typename WarpGemm::AWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr; @@ -42,15 +46,11 @@ struct BlockGemmASmemBSmemCReg static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM); - constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; @@ -70,11 +70,7 @@ struct BlockGemmASmemBSmemCReg // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN); - constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; @@ -99,24 +95,24 @@ struct BlockGemmASmemBSmemCReg using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + ALdsTile aWarpTile; + BLdsTile bWarpTile; // Prefetch from LDS to warp register template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - load_tile(a_warp_tile_, a_block_window); - load_tile(b_warp_tile_, b_block_window); + aWarpTile = load_tile(a_block_window); + bWarpTile = load_tile(b_block_window); } #endif // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v && @@ -131,17 +127,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -149,13 +139,13 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -165,19 +155,18 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -187,48 +176,46 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { -#if defined(ENABLE_INSTRUCTION_SCH) -#pragma message("local data share prefetch") - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) +#pragma message("local data share prefetch") + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { -#if defined(ENABLE_INSTRUCTION_SCH) - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), @@ -240,8 +227,8 @@ struct BlockGemmASmemBSmemCReg // C = A * B template - CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const { static_assert(std::is_same_v && std::is_same_v, @@ -255,17 +242,11 @@ struct BlockGemmASmemBSmemCReg KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t{}))>; - - constexpr index_t MWarp = config.template get(number<1>{}); - constexpr index_t NWarp = config.template get(number<2>{}); - - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; +#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -273,13 +254,13 @@ struct BlockGemmASmemBSmemCReg const index_t iMWarp = get_warp_id() / NWarp; const index_t iNWarp = get_warp_id() % NWarp; - // construct A-warp-window + // Construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -289,19 +270,18 @@ struct BlockGemmASmemBSmemCReg static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); - // construct B-warp-window + // Construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -311,13 +291,13 @@ struct BlockGemmASmemBSmemCReg static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - move_tile_window(b_warp_windows(nIter)(kIter), {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); +#endif - static_assert(std::is_same_v, "wrong!"); + static_assert(std::is_same_v, "wrong!"); // Construct C-Block-Tensor constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -329,44 +309,42 @@ struct BlockGemmASmemBSmemCReg sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - // hot loop: + // Hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { -#if defined(ENABLE_INSTRUCTION_SCH) - // read A warp tensor from A block tensor + // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); #else - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); #endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { -#if defined(ENABLE_INSTRUCTION_SCH) - // read B warp tensor from B block tensor + // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); #else - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); #endif - // read C warp tensor from C block tensor + // Read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - // warp GEMM + // Warp GEMM if constexpr(KIterPerWarp == 0) { // c = a * b - c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); } else { @@ -375,10 +353,10 @@ struct BlockGemmASmemBSmemCReg merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); } - // write C warp tensor into C block tensor + // Write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index a9fa1e9436..2fdb63794f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -17,19 +17,29 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { +#if defined(ADJUST_BLOCK_TILE_SHAPE) + constexpr index_t kMWarp = 2; + constexpr index_t kNWarp = 2; +#else + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; +#endif + #if defined(NAIVE_IMPLEMENTATION) #pragma message("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_32x32x_8x2) #pragma message("mfma m32 n32 k16") @@ -37,13 +47,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_16x16x16) #pragma message("mfma m16 n16 k16") @@ -51,13 +63,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); } #elif defined(USING_MFMA_16x16x_16x2) #pragma message("mfma m16 n16 k32") @@ -65,13 +79,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2); + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); } #endif else diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 3272835bf3..1aa11791dd 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -42,23 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg } #if defined(ENABLE_INSTRUCTION_SCH) - static constexpr index_t APackedSize = + static constexpr index_t kPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - using I0 = number<0>; - using I1 = number<1>; - using I2 = number<2>; - - static constexpr index_t BlockSize = Problem::kBlockSize; - static constexpr index_t MPerBlock = BlockGemmShape::kM; - static constexpr index_t NPerBlock = BlockGemmShape::kN; - static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -74,35 +59,35 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; constexpr index_t WaveSize = 64; - constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); - constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + constexpr index_t WaveNumM = BlockGemm::MWarp; + constexpr index_t WaveNumN = BlockGemm::NWarp; constexpr index_t AB_LDS_RW_Width = GetSmemPack(); constexpr index_t A_Buffer_Load_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB()); constexpr index_t A_LDS_Write_Inst_Num = - MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Write_Inst_Num = - NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width); + WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); - constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / - (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 + constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 + constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2; @@ -116,9 +101,9 @@ struct BlockGemmPipelineAGmemBGmemCReg constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; constexpr auto ds_read_a_issue_cycle = - AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = @@ -275,18 +260,18 @@ struct BlockGemmPipelineAGmemBGmemCReg {0, 0}, b_copy_dram_window.get_tile_distribution()); -#if defined(ENABLE_INSTRUCTION_SCH) +#if defined(ENABLE_PREFETCH) // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( a_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( b_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); #else @@ -313,59 +298,63 @@ struct BlockGemmPipelineAGmemBGmemCReg ABlockTile a_block_tile; BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); // ------------------------------------------------------------------------------------- // Gemm pipeline start -#if defined(ENABLE_INSTRUCTION_SCH) - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; - using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); - - // Prefetch - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - load_tile(b_block_tile, b_copy_dram_window); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - // LDS write 0 - store_tile(a_copy_lds_window, a_block_tile); - store_tile(b_copy_lds_window, b_block_tile); - +#if defined(ENABLE_PREFETCH) +#pragma message("global prefetch") + // Prefetch // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - load_tile(b_block_tile, b_copy_dram_window); - move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); - block_sync_lds(); + if(num_loop > 1) + { + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - // Prefetch from LDS to warp register in block gemm - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS write 0 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } __builtin_amdgcn_sched_barrier(0); // Main body - if constexpr(HasHotLoop) + if(num_loop > 2) { - index_t i = 0; + index_t iCounter = 0; do { block_sync_lds(); - // LDS write 0 + // LDS write 1 store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); + // Global read 2 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - load_tile(b_block_tile, b_copy_dram_window); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); @@ -375,116 +364,37 @@ struct BlockGemmPipelineAGmemBGmemCReg // Prefetch from LDS to warp register in block gemm block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); +#if defined(ENABLE_INSTRUCTION_SCH) HotLoopScheduler(); +#endif __builtin_amdgcn_sched_barrier(0); - i += 1; - } while(i < (num_loop - 2)); + iCounter += 1; + } while(iCounter < (num_loop - 2)); } // Tail - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); + if(num_loop > 1) + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - -#elif defined(ENABLE_PREFETCH) - // Prefetch - // Global read 0 - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); - - { - // Move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // Initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - store_tile(a_copy_lds_window, a_block_tile); - // Global read 1 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write 0 - store_tile(b_copy_lds_window, b_block_tile); - // Global read 1 - load_tile(b_block_tile, b_copy_dram_window); - } - - index_t iCounter = num_loop - 2; - - do - { - block_sync_lds(); - - // GEMM i - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // Move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - store_tile(a_copy_lds_window, a_block_tile); - // Global read i + 2 - load_tile(a_block_tile, a_copy_dram_window); - - // LDS write i + 1 - store_tile(b_copy_lds_window, b_block_tile); - // Global read i + 2 - load_tile(b_block_tile, b_copy_dram_window); - - iCounter--; - - } while(iCounter > 0); - - // Tail - { - block_sync_lds(); - - // GEMM num_loop - 2 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - - block_sync_lds(); - - // LDS write num_loop - 1 - store_tile(a_copy_lds_window, a_block_tile); - - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - - // GEMM num_loop - 1 - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - } #else // non-prefetch - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); - store_tile(a_copy_lds_window, a_block_tile); - store_tile(b_copy_lds_window, b_block_tile); - - block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); - block_sync_lds(); - - index_t iCounter = num_loop - 1; + index_t iCounter = num_loop; while(iCounter > 0) { - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - - load_tile(a_block_tile, a_copy_dram_window); - load_tile(b_block_tile, b_copy_dram_window); + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); store_tile(a_copy_lds_window, a_block_tile); store_tile(b_copy_lds_window, b_block_tile); diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 3ddd37ce1a..30bdcd679f 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -313,26 +313,18 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - static_assert(std::is_same_v); - + using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - return GetGlobalVectorLoadSize(); } template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - static_assert(std::is_same_v); - + using BDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - return GetGlobalVectorLoadSize(); } diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 88a35ec2d2..8b581204b7 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -29,28 +29,6 @@ struct GridGemmProblem using CElementFunction = CElementFunction_; }; -#if defined(ENABLE_INSTRUCTION_SCH) -template -struct TileGemmShape -{ - using BlockTile = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - - static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); - - static constexpr index_t kM = BlockTile::at(number<0>{}); - static constexpr index_t kN = BlockTile::at(number<1>{}); - static constexpr index_t kK = BlockTile::at(number<2>{}); - - static constexpr bool PermuteA = PermuteA_; - static constexpr bool PermuteB = PermuteB_; -}; -#else template struct TileGemmShape { @@ -58,71 +36,7 @@ struct TileGemmShape static constexpr index_t kN = kNPerTile; static constexpr index_t kK = kKPerTile; }; -#endif -#if defined(ENABLE_INSTRUCTION_SCH) -template -struct TileGemmTraits -{ - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool kPadK = kPadK_; - - static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; - - static constexpr bool TransposeC = TransposeC_; -}; - -template -struct BlockGemmPipelineProblem -{ - using Traits = remove_cvref_t; - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - - using BlockGemmShape = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadM = Traits::kPadM; - static constexpr bool kPadN = Traits::kPadN; - static constexpr bool kPadK = Traits::kPadK; - - static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; - - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; - - static constexpr bool TransposeC = Traits::TransposeC; -}; -#else template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() { -#if defined(ENABLE_INSTRUCTION_SCH) - // Block GEMM pipeline w/ instruction scheduling - using GemmShape = TileGemmShape, - sequence, - sequence, - PermuteA, - PermuteB>; - - using GemmTraits = TileGemmTraits; - - using BlockGemmPipelineProblem_ = - BlockGemmPipelineProblem; -#else using BlockGemmPipelineProblem_ = BlockGemmPipelineProblem>; - -#endif return BlockGemmPipelineAGmemBGmemCReg{}; } }; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 index 44dfac099c..23fc7484dd --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt @@ -10,6 +10,12 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF) +if(ENABLE_TOY_FA_FWD_OPT) + message("Compiling with toy FA fwd optimization") + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT) +endif() + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 830b2422b5..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -26,6 +26,251 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, @@ -38,6 +283,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -46,7 +293,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -180,6 +427,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -188,7 +437,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index fb1516eb52..8994689841 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1DefaultPolicy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 32dc09f95e..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1K8Policy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 25e9ea7e1a..cfbd7d6376 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -58,8 +55,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; @@ -135,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -159,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -218,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -257,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); if constexpr(k_loops > 1) { + // LDS write 0 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, i_k0 * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -316,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } @@ -336,9 +374,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } @@ -350,7 +388,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 91b92ed51a..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,43 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -48,13 +20,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; - -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -62,11 +27,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = remove_cvref_t; + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = AKDim; constexpr auto config = - BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -91,6 +58,87 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index ad6d6d3996..0000000000 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,180 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto b_lds_block_desc = - transform_tensor_descriptor(b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index fba49f9de5..4ce61ed20c 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,32 +29,28 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index caeeece8e9..4317ebee8d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -8,13 +8,69 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" namespace ck_tile { +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) // O[M0, N1] = P[M0, N0] * V[N1, N0] @@ -53,25 +109,38 @@ struct FlashAttentionFwd const index_t BatchStrideV, const index_t BatchStrideO) const { - // divide problem - const index_t num_tile_m0 = M0 / kM0PerBlock; - const index_t num_tile_n1 = N1 / kN1PerBlock; - const index_t id_block = get_block_id(); + const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); + const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); + +#if defined(TOY_FA_FWD_OPT) +#pragma message("Enable toy FA fwd opt") + const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); + + const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; + const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); + + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % + num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % + num_tile_n1 * kN1PerBlock); + +#else const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; index_t modulus = dividend - quotient * divisor; return make_tuple(quotient, modulus); }; - const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); + const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); - const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); +#endif const auto kernel_impl = FlashAttentionFwdImpl{}, number{}, number{}), - make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number{}, number<1>{}), - number{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } @@ -132,6 +152,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -146,7 +170,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -156,22 +183,32 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -209,22 +246,19 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -274,10 +308,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -288,29 +342,58 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + gemm1(o_acc, get_slice_tile(p, sequence<0, i_k1 * kK1PerBlock>{}, sequence{}), - v_lds_window); + vWarpTile); block_sync_lds(); - store_tile(v_lds_window, v); - move_tile_window(v_dram_window, {0, kK1PerBlock}); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr(k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1PerBlock>{}, sequence{}), - v_lds_window); + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt old mode 100755 new mode 100644 index cc9873eeb3..e9a697fde3 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt @@ -7,8 +7,8 @@ endif() execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} - --working_path ${CMAKE_CURRENT_BINARY_DIR} + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs RESULT_VARIABLE ret ) @@ -21,21 +21,21 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATT add_custom_command( OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} - --working_path ${CMAKE_CURRENT_BINARY_DIR} + --api ${FLASH_ATTENTION_FWD_ENABLE_APIS} + --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs ) set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd") message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} - EXCLUDE_FROM_ALL +add_executable(${EXAMPLE_REDUCE} + EXCLUDE_FROM_ALL flash_attention_fwd.cpp ) -target_include_directories(${EXAMPLE_REDUCE} - PRIVATE +target_include_directories(${EXAMPLE_REDUCE} + PRIVATE ${CMAKE_CURRENT_LIST_DIR} ) @@ -45,14 +45,14 @@ message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}") set(EXAMPLE_REDUCE_COMPILE_OPTIONS) -list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS - -Wno-undefined-func-template - -Wno-float-equal +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS + -Wno-undefined-func-template + -Wno-float-equal --offload-compress ) -target_compile_options(${EXAMPLE_REDUCE} - PRIVATE +target_compile_options(${EXAMPLE_REDUCE} + PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS} ) diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 830b2422b5..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -26,6 +26,251 @@ struct BlockGemmARegBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN); + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + template + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MPerXDL = WG::kM; + constexpr index_t NPerXDL = WG::kN; + constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNumM = config.template get<1>(); + + constexpr index_t B_LDS_RW_Width = SmemPack; + + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * VectorSizeB); + + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // B split schedule + constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_b_issue_cycle = + B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BLdsTile& b_block_tensor_tmp) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + // C += A * B template __device__ void operator()(CBlockTensor& c_block_tensor, @@ -38,6 +283,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -46,7 +293,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -180,6 +427,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -188,7 +437,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index fb1516eb52..8994689841 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1DefaultPolicy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 32dc09f95e..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1K8Policy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + if constexpr(kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr(kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr(kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 25e9ea7e1a..cfbd7d6376 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -13,16 +13,13 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCReg< - Problem, - BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy> +template +struct BlockGemmPipelineAGmemBGmemCReg { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -58,8 +55,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; @@ -135,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg< b_block_tile = load_tile(b_copy_dram_window); } + __builtin_amdgcn_sched_barrier(0); + if constexpr(k_loops > 2) { static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { @@ -159,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< store_tile(b_copy_lds_window, b_block_tile); b_block_tile = load_tile(b_copy_dram_window); + + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } @@ -218,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< ignore = b_element_func; + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + // A tile in Reg,blockTensor // This tensor distribution used to construct both distributed tensor for local buffer store // and read. without buffer address info @@ -257,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg< // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - constexpr auto block_gemm = Policy::template GetBlockGemm(); + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile auto c_block_tile = decltype(block_gemm( get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - auto b_block_tile = load_tile(b_copy_dram_window); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); +#if !defined(TOY_FA_FWD_OPT) + static_for<0, k_loops, 1>{}([&](auto i_k0) { + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); + block_sync_lds(); + }); +#else + using BLdsTile = typename decltype(block_gemm)::BLdsTile; + BLdsTile bWarpTile; + + // Global read 0 + auto b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); if constexpr(k_loops > 1) { + // LDS write 0 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(b_copy_dram_window, {0, kKPerBlock}); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_sync_lds(); + + // LDS read 0 + bWarpTile = load_tile(b_lds_gemm_window); } if constexpr(k_loops > 2) { + __builtin_amdgcn_sched_barrier(0); static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); + // LDS write 1 + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, i_k0 * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // LDS read 1 + bWarpTile = load_tile(b_lds_gemm_window); - store_tile(b_copy_lds_window, b_block_tile); - b_block_tile = load_tile(b_copy_dram_window); + block_gemm.HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); }); } - // tail { if constexpr(k_loops > 1) { - block_sync_lds(); - block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); block_sync_lds(); } @@ -316,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg< block_sync_lds(); + bWarpTile = load_tile(b_lds_gemm_window); + block_gemm(c_block_tile, get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - b_copy_lds_window); + bWarpTile); } - +#endif return c_block_tile; } @@ -336,9 +374,9 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } @@ -350,7 +388,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< { return operator()( b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](const BDataType & b) { return b; }, a_reg_block_tensor_tmp, p_smem); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 91b92ed51a..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -3,43 +3,15 @@ #pragma once -#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { -// NOTE: Assume A is K-Major -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy +template +struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy { - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_reg_block_descriptor(); - } - - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - return policy_impl::make_b_lds_block_descriptor_3d_pad(); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = remove_cvref_t; - - return policy_impl::make_a_dram_tile_distribution_skip_lds(); - } - - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - return policy_impl::make_b_dram_tile_distribution(); - } + static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto GetBlockGemm() @@ -48,13 +20,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy return BlockGemmARegBSmemCRegV1{}; } -}; - -template -struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy - : BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy -{ - static constexpr index_t AKDim = AKDim_; template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() @@ -62,11 +27,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = remove_cvref_t; + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = AKDim; constexpr auto config = - BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -91,6 +58,87 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp deleted file mode 100644 index ad6d6d3996..0000000000 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ /dev/null @@ -1,180 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_distribution.hpp" - -namespace ck_tile { -namespace policy_impl { - -// 3d + padding -template -__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = - transform_tensor_descriptor(a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; -} - -// 3d + padding -template -__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() -{ - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto b_lds_block_desc = - transform_tensor_descriptor(b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; -} - -template -__host__ __device__ static constexpr auto make_a_reg_block_descriptor() -{ - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - constexpr index_t NWarp = config.template get<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution() -{ - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() -{ - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template get<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; - // // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, - sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); -} - -template -__host__ __device__ static constexpr auto make_b_dram_tile_distribution() -{ - using BDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -} - -template -__host__ __device__ static constexpr auto get_block_gemm() -{ - using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy; - - return BlockGemmASmemBSmemCReg{}; -} - -} // namespace policy_impl -} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 9b7c9b1c6c..3750ede188 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,32 +29,28 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index b8c5518b23..38c56a27e8 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -9,13 +9,69 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" #include "block_gemm_pipeline_problem.hpp" #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" namespace ck_tile { +CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) +{ + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +} + template struct FlashAttnArgs { @@ -83,25 +139,21 @@ struct FlashAttentionFwd const index_t BatchStrideV, const index_t BatchStrideO) const { - // divide problem - const index_t num_tile_m0 = M0 / kM0PerBlock; - const index_t num_tile_n1 = N1 / kN1PerBlock; - const index_t id_block = get_block_id(); - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; + const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock); + const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock); - return make_tuple(quotient, modulus); - }; + const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1); - const auto [itmp, id_tile_n] = f(id_block, num_tile_n1); - const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0); + const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0; + const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0); const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch); - const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); - const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); + const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) % + num_tile_m0 * kM0PerBlock); + const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % + num_tile_n1 * kN1PerBlock); const auto kernel_impl = FlashAttentionFwdImpl -// struct flash_attention_fwd_traits_ -// { -// using SaccDataType = ck_tile::remove_cvref_t; -// using SMPLComputeDataType = ck_tile::remove_cvref_t; -// using PDataType = ck_tile::remove_cvref_t; -// using OaccDataType = ck_tile::remove_cvref_t; - -// static constexpr index_t kBlockSize = kBlockSize_; -// static constexpr index_t kHeadDim = kHeadDim_; -// static constexpr index_t kM0PerBlock = kM0PerBlock_; -// static constexpr index_t kN0PerBlock = kN0PerBlock_; -// static constexpr index_t kK0PerBlock = kK0PerBlock_; -// static constexpr index_t kN1PerBlock = kN1PerBlock_; -// static constexpr index_t kK1PerBlock = kK1PerBlock_; - -// static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD -// static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; -// static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -// }; - -// // TODO: fwd_api.cpp, fwd_common.cpp -// template -// using traits_ = flash_attention_fwd_traits_; -// // fw_api.cpp -// // Note: this internal API only declare, not define here, otherwise will block `make -j` -// template -// float flash_attention_fwd_(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config); - -// // TODO: fwd_common.cpp -// template -// float flash_attention_fwd_(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config) { -// using SaccDataType = typename Traits_::SaccDataType; -// using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -// using PDataType = typename Traits_::PDataType; -// using OaccDataType = typename Traits_::OaccDataType; - -// index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - -// std::cout << "grid size " << kGridSize << std::endl; - -// return ck_tile::launch_kernel(stream_config, -// ck_tile::make_kernel( -// ck_tile::FlashAttentionFwd{}, -// kGridSize, -// Traits_::kBlockSize, -// 0, -// a.q_ptr, -// a.k_ptr, -// a.v_ptr, -// a.o_ptr, -// a.M0, -// a.N0, -// a.K0, -// a.N1, -// a.Batch, -// a.strideQ, // StrideQ -// a.strideK, // StrideK -// a.strideV, // StrideV -// a.strideO, // StrideO -// a.batchStrideQ, // BatchStrideQ -// a.batchStrideK, // BatchStrideK -// a.batchStrideV, // BatchStrideV -// a.batchStrideO)); // BatchStrideO -// } - -// // TODO: change to only declare -// // TODO: fwd_api.cpp -// template -// float flash_attention_fwd(const FlashAttnArgs& a, -// const ck_tile::stream_config& stream_config) { -// constexpr ck_tile::index_t kM0PerBlock = 128; -// constexpr ck_tile::index_t kN0PerBlock = 128; -// constexpr ck_tile::index_t kK0PerBlock = 32; -// constexpr ck_tile::index_t kN1PerBlock = 128; -// constexpr ck_tile::index_t kK1PerBlock = 32; - -// constexpr ck_tile::index_t kBlockSize = 256; -// constexpr ck_tile::index_t kHeadDim = 128; - -// return flash_attention_fwd_> -// (a, stream_config); - -// } - -// TODO: change to only declare -// TODO: fwd_api.cpp template {}, number{}, number{}), - make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number{}, number<1>{}), - number{}, + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, number<1>{}); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } @@ -132,6 +152,10 @@ struct FlashAttentionFwdImpl constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; @@ -146,7 +170,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -156,22 +183,32 @@ struct FlashAttentionFwdImpl make_tuple(number{}, number{}), {iN1, 0}, MakeVDramTileDistribution()); - - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), MakeVLdsBlockDescriptor()); +#if defined(TOY_FA_FWD_OPT) + // V LDS tile window for store + auto v_copy_lds_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); + + // V LDS tile for block GEMM + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); +#else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; - constexpr auto gemm1 = BlockGemm1{}; +#endif // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -209,22 +246,19 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); +#if defined(TOY_FA_FWD_OPT) // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - + auto v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, sequence<1>{}, f_max, std::numeric_limits::lowest()); @@ -274,10 +308,30 @@ struct FlashAttentionFwdImpl o_acc(i_j_idx) *= tmp; }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); +#if !defined(TOY_FA_FWD_OPT) + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + move_tile_window(v_dram_window, {0, kK1PerBlock}); + store_tile(v_lds_window, v); + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); + block_sync_lds(); + }); +#else + using VLdsTile = typename decltype(gemm1)::BLdsTile; + VLdsTile vWarpTile; // type cast Pcompute{j} into P{j} const auto p = @@ -288,29 +342,58 @@ struct FlashAttentionFwdImpl if constexpr(k1_loops > 1) { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + store_tile(v_copy_lds_window, v_prefetch); + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); + } + if constexpr(k1_loops > 2) + { + __builtin_amdgcn_sched_barrier(0); + static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) { block_sync_lds(); + + // LDS write 1 + store_tile(v_copy_lds_window, v_prefetch); + + // Global read 2 + v_prefetch = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + gemm1(o_acc, get_slice_tile(p, sequence<0, i_k1 * kK1PerBlock>{}, sequence{}), - v_lds_window); + vWarpTile); block_sync_lds(); - store_tile(v_lds_window, v); - move_tile_window(v_dram_window, {0, kK1PerBlock}); + vWarpTile = load_tile(v_lds_gemm_window); + gemm1.template HotLoopScheduler<8, 4>(); + __builtin_amdgcn_sched_barrier(0); }); } // tail { + if constexpr(k1_loops > 1) + { + gemm1(o_acc, + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); + block_sync_lds(); + } + store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); + vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1PerBlock>{}, sequence{}), - v_lds_window); + vWarpTile); block_sync_lds(); } +#endif // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 4a36179dbe..00bc91cadc 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -11,16 +11,8 @@ import itertools import copy from dataclasses import dataclass -# def get_if_str(idx, total, last_else=True): -# if idx == 0: -# return 'if' -# elif idx < total - 1: -# return 'else if' -# else: -# return 'else' if last_else else 'else if' - def get_if_str(size_, total, last_else=True): - if size_ == "small": + if size_ == "head_dim_256_seq_4096": return 'if' else: return 'else if' @@ -34,18 +26,18 @@ def BOOL_MAP(b_) -> str: class FlashAttentionFwdCodegen: API_TRAITS_DEFINE = """ - + template + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ { using SaccDataType = ck_tile::remove_cvref_t; @@ -60,23 +52,23 @@ struct flash_attention_fwd_traits_ static constexpr index_t kK0PerBlock = kK0PerBlock_; static constexpr index_t kN1PerBlock = kN1PerBlock_; static constexpr index_t kK1PerBlock = kK1PerBlock_; - + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}; - +}; + template + ck_tile::index_t kBlockSize = 256, + ck_tile::index_t kHeadDim = 128, + ck_tile::index_t kM0PerBlock = 128, + ck_tile::index_t kN0PerBlock = 128, + ck_tile::index_t kK0PerBlock = 64, + ck_tile::index_t kN1PerBlock = 128, + ck_tile::index_t kK1PerBlock = 64> using traits_ = flash_attention_fwd_traits_; """ -# API_COMMON_HEADER = """ -# // SPDX-License-Identifier: MIT -# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -# #include -# #include "flash_attention_fwd.hpp" -# #include - -# #pragma once - -# using S = ck_tile::stream_config; -# using A = FlashAttnArgs; - -# {F_traits_define} - -# template -# float flash_attention_fwd_(const FlashAttnArgs& a, -# const ck_tile::stream_config& stream_config) {{ -# using SaccDataType = typename Traits_::SaccDataType; -# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -# using PDataType = typename Traits_::PDataType; -# using OaccDataType = typename Traits_::OaccDataType; - -# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - -# if(stream_config.log_level_ > 0) -# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; - -# return ck_tile::launch_kernel(stream_config, -# ck_tile::make_kernel( -# ck_tile::FlashAttentionFwd{{}}, -# kGridSize, -# Traits_::kBlockSize, -# 0, -# a.q_ptr, -# a.k_ptr, -# a.v_ptr, -# a.o_ptr, -# a.M0, -# a.N0, -# a.K0, -# a.N1, -# a.Batch, -# a.strideQ, // StrideQ -# a.strideK, // StrideK -# a.strideV, // StrideV -# a.strideO, // StrideO -# a.batchStrideQ, // BatchStrideQ -# a.batchStrideK, // BatchStrideK -# a.batchStrideV, // BatchStrideV -# a.batchStrideO)); // BatchStrideO -# }} -# """ - API_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. @@ -204,14 +124,6 @@ template float flash_attention_fwd && std::is_same_v && std::is_same_v && std::is_same_v) {{ -# {F_per_size_case} -# }} -# """ -# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{ -# {F_inner_dispatch} -# }} -# """ API_INNER_CASE = """ {F_if} {F_VEC_COND} r = flash_attention_fwd_>(a, stream_config); """ @@ -224,7 +136,7 @@ template float flash_attention_fwd + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ {{ using SaccDataType = ck_tile::remove_cvref_t; @@ -341,13 +253,13 @@ struct flash_attention_fwd_traits_ static constexpr index_t kK0PerBlock = kK0PerBlock_; static constexpr index_t kN1PerBlock = kN1PerBlock_; static constexpr index_t kK1PerBlock = kK1PerBlock_; - + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}}; - - +}}; + + template float flash_attention_fwd_(const FlashAttnArgs& a, const ck_tile::stream_config& stream_config) {{ - using SaccDataType = typename Traits_::SaccDataType; - using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; - using PDataType = typename Traits_::PDataType; - using OaccDataType = typename Traits_::OaccDataType; - + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; + index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); if(stream_config.log_level_ > 0) @@ -433,7 +345,7 @@ float flash_attention_fwd_(const FlashAttnArgs None: w_p = Path(self.working_path) list_p = w_p / 'flash_attention_fwd_blobs.txt' blobs = self.get_blobs(args) - + with list_p.open('w') as list_f: # API related files list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") @@ -557,11 +467,11 @@ float flash_attention_fwd_(const FlashAttnArgs