From b3100b6f434caa2924e3dfe11a0ba5b3d7258aaa Mon Sep 17 00:00:00 2001 From: danyao12 Date: Sat, 20 Jul 2024 16:09:14 +0800 Subject: [PATCH] remove FmhaBwdTilePartitioner --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 9 ++-- include/ck_tile/ops/fmha.hpp | 1 - .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 52 +++++++++++++------ .../fmha/kernel/fmha_bwd_tile_partitioner.hpp | 50 ------------------ 4 files changed, 40 insertions(+), 72 deletions(-) delete mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 36033b43c1..1a88e30dde 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< false>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = - ck_tile::FmhaBwdDQDKDVKernel, - fmha_bwd_pipeline_{F_idx}, + ck_tile::FmhaBwdDQDKDVKernel; @@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO; using fmha_bwd_dot_do_o_kernel_{F_idx} = - ck_tile::FmhaBwdOGradDotOKernel, - fmha_bwd_dot_do_o_{F_idx}>; + ck_tile::FmhaBwdOGradDotOKernel; using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; @@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} = typename ck_tile::BlockFmhaBwdConvertQGrad; using fmha_bwd_convert_dq_kernel_{F_idx} = - ck_tile::FmhaBwdConvertQGradKernel, - fmha_bwd_convert_dq_{F_idx}>; + ck_tile::FmhaBwdConvertQGradKernel; using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 0f7e03de95..408f066236 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -8,7 +8,6 @@ #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index df30e8b163..f84cb34e60 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -23,13 +23,9 @@ namespace ck_tile { -template +template struct FmhaBwdDQDKDVKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaPipeline = ck_tile::remove_cvref_t; using KGradEpiloguePipeline = ck_tile::remove_cvref_t; using VGradEpiloguePipeline = ck_tile::remove_cvref_t; @@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + return dim3( + batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0)); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.x; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex(); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); @@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel } }; -template +template struct FmhaBwdOGradDotOKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; @@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // divide problem - const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); @@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel } }; -template +template struct FmhaBwdConvertQGradKernel { - using TilePartitioner = ck_tile::remove_cvref_t; using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu; @@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex() + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // divide problem - const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp deleted file mode 100644 index 676e6f55e9..0000000000 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -template -struct FmhaBwdKTilePartitioner -{ - CK_TILE_HOST static constexpr auto - GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) - { - // TODO: this may need tuning - return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0)); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) - { - const index_t i_block = blockIdx.z; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.x; - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); - } -}; - -template -struct FmhaBwdQTilePartitioner -{ - CK_TILE_HOST static constexpr auto - GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_); - } - - CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) - { - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); - } -}; - -} // namespace ck_tile