From 3db57fc3488f422cdccb86cea3584da82efc3b44 Mon Sep 17 00:00:00 2001 From: asleepzzz Date: Thu, 28 Aug 2025 22:50:42 +0800 Subject: [PATCH] Revert "[CK_TILE] FMHA BWD Enable Tile 16x192 (#2741)" (#2757) This reverts commit f2e6edde3bd2b890733bf0ea011743d8d1deca49. [ROCm/composable_kernel commit: 038ea82315d7b45f31b807a69b80c2fb8c687d71] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 35 ++--- example/ck_tile/01_fmha/fmha_bwd.cpp | 1 + example/ck_tile/01_fmha/fmha_bwd.hpp | 12 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 134 +++++++----------- ...wd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 40 +++--- ...mha_bwd_pipeline_trload_default_policy.hpp | 65 +++------ 6 files changed, 114 insertions(+), 173 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bee1c77c7b..0391191fb2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -125,8 +125,7 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dvpad}, {F_deterministic}, {F_trload}, - {F_maxq}, - {F_bn0}>; + {F_maxq}>; #include @@ -219,10 +218,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; r = fmha_bwd_>(s, a); return r; }} @@ -387,7 +386,6 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize] elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': return [ FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), ] @@ -521,8 +519,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_mode}, {F_spad}, {F_dpad}, - {F_deterministic}, - {F_bn0}>; + {F_deterministic}>; #include @@ -659,17 +656,6 @@ class FmhaBwdApiTrait: if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' else : return f'a.hdim_v % {self.bhdv} == 0' - @property - def extra_cond(self) -> str: - if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: - return "&& (a.seqlen_k <= 256)" - else: - return "" - - @property - def convert_dq_bn0(self) -> int: - return self.tile.F_bn0 if self.deterministic == 't' else 0 - @property def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy @@ -694,7 +680,7 @@ class FmhaBwdApiTrait: return 2 return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -722,8 +708,7 @@ class FmhaBwdApiPool: F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, - F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, - F_convert_dq_bn0=trait.convert_dq_bn0) + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled]) i += 1 return inners @@ -806,9 +791,6 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if tr_load == "t" and (dpad == "t" or dvpad == "t"): continue # tr_load cannot work with dpad or dvpad - if optdim_list != [-1]: - if hdim not in optdim_list: - continue t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): @@ -817,6 +799,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index b51886e6d8..9f1e0f6948 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -803,6 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); dbias_buf.SetZero(); dq_acc_buf.SetZero(); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..8d35b2d12c 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -372,8 +372,7 @@ template + ck_tile::index_t MaxSeqLenQ_> struct fmha_bwd_dq_dk_dv_traits_ { }; @@ -413,10 +412,15 @@ template + bool kIsDeterministic_> struct fmha_bwd_convert_dq_traits_ { + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..9bd78b4077 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -103,41 +103,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); const auto do_lds_ptr1 = reinterpret_cast( smem_ptr_ + Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr0 = reinterpret_cast( // + const auto q_lds_ptr0 = reinterpret_cast( // smem_ptr_ + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad()); - const auto q_lds_ptr1 = reinterpret_cast( // + const auto q_lds_ptr1 = reinterpret_cast( // smem_ptr_ + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeQ()); - const auto lse_lds_ptr0 = reinterpret_cast( + const auto lse_lds_ptr = reinterpret_cast( smem_ptr_ + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); - const auto lse_lds_ptr1 = reinterpret_cast( + const auto d_lds_ptr = reinterpret_cast( smem_ptr_ + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeLSE()); - const auto d_lds_ptr0 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + - Policy::template GetSmemSizeLSE()); - const auto d_lds_ptr1 = reinterpret_cast( - smem_ptr_ + Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeOGrad() + - Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); const auto ds_lds_ptr = reinterpret_cast( smem_ptr_ + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeOGrad() + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + - Policy::template GetSmemSizeLSE() + - Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD() + - Policy::template GetSmemSizeD()); + Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); return run(k_lds_ptr, v_lds_ptr, @@ -145,10 +131,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR do_lds_ptr1, q_lds_ptr0, q_lds_ptr1, - lse_lds_ptr0, - lse_lds_ptr1, - d_lds_ptr0, - d_lds_ptr1, + lse_lds_ptr, + d_lds_ptr, ds_lds_ptr, bias_lds_ptr, std::forward(args)...); @@ -172,10 +156,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR OGradDataType* __restrict__ do_lds_ptr1, QDataType* __restrict__ q_lds_ptr0, QDataType* __restrict__ q_lds_ptr1, - LSEDataType* __restrict__ lse_lds_ptr0, - LSEDataType* __restrict__ lse_lds_ptr1, - DDataType* __restrict__ d_lds_ptr0, - DDataType* __restrict__ d_lds_ptr1, + LSEDataType* __restrict__ lse_lds_ptr, + DDataType* __restrict__ d_lds_ptr, GemmDataType* __restrict__ ds_lds_ptr, BiasDataType* __restrict__ bias_lds_ptr, const QDramBlockWindowTmp& q_dram_block_window_tmp, @@ -407,38 +389,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR "BiasDataType and BiasGradDataType should be the same!"); // LSE: HBM -> LDS ->Reg - auto lse_dram_window = - make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}, - Policy::template MakeLSEDDramTileDistribution()); + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); auto lse_lds = make_tensor_view( - lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); - auto lse_lds_read_window = - make_tile_window(lse_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); // D: HBM ->Reg - auto d_dram_window = - make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - {seqlen_q_start}, - Policy::template MakeLSEDDramTileDistribution()); + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); auto d_lds = make_tensor_view( - d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); - auto d_lds_read_window = - make_tile_window(d_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); // RandVal: HBM ->Reg auto randval_dram_window = dropout.template MakeRandvalDramWindow( @@ -489,31 +471,27 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; decltype(gemm_4.MakeCBlockTile()) dq_acc; + decltype(load_tile(lse_dram_window)) lse_block_tile; + decltype(load_tile(d_dram_window)) d_block_tile; + index_t i_total_bodys = 0; auto main_body_impl = [&](auto is_prologue_, auto is_epilogue_, QDataType* const __restrict__ q_lds_ptr_curr, QDataType* const __restrict__ q_lds_ptr_next, OGradDataType* const __restrict__ do_lds_ptr_curr, - OGradDataType* const __restrict__ do_lds_ptr_next, - LSEDataType* const __restrict__ lse_lds_ptr_curr, - LSEDataType* const __restrict__ lse_lds_ptr_next, - DDataType* const __restrict__ d_lds_ptr_curr, - DDataType* const __restrict__ d_lds_ptr_next - - ) mutable { + OGradDataType* const __restrict__ do_lds_ptr_next) mutable { constexpr bool is_prologue = is_prologue_.value; constexpr bool is_epilogue = is_epilogue_.value; static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); constexpr bool is_main_body = is_prologue && is_epilogue; + if constexpr(is_prologue) { - lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next); - async_load_tile(lse_lds_write_window, lse_dram_window); + lse_block_tile = load_tile(lse_dram_window); move_tile_window(lse_dram_window, {kM0}); - d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next); - async_load_tile(d_lds_write_window, d_dram_window); + d_block_tile = load_tile(d_dram_window); move_tile_window(d_dram_window, {kM0}); q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); @@ -532,13 +510,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); dot_reg_tensor = load_tile_transpose(dot_lds_read_window); } - if constexpr(is_epilogue) - { - lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr); - lse = load_tile(lse_lds_read_window); - d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr); - d = load_tile(d_lds_read_window); - } if constexpr(is_main_body) Policy::template HotLoopScheduler::SchedulerGemm0(); __builtin_amdgcn_sched_barrier(0); @@ -646,6 +617,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR if constexpr(is_main_body) Policy::template HotLoopScheduler::SchedulerGemm12(); __builtin_amdgcn_sched_barrier(0); + if constexpr(is_prologue) + { + store_tile(lse_lds_write_window, lse_block_tile); + store_tile(d_lds_write_window, d_block_tile); + } if constexpr(is_epilogue) { // STAGE 5, P^T(PGrad^T - D) @@ -700,12 +676,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR store_tile(ds_lds_window, ds_gemm); } - s_waitcnt(); + __builtin_amdgcn_s_waitcnt(3952); block_sync_lds(); if constexpr(is_prologue) { q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); q_reg_tensor = load_tile(q_lds_read_window); + lse = load_tile(lse_lds_read_window); } if constexpr(is_epilogue) { @@ -743,6 +720,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR { do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); do_reg_tensor = load_tile(do_lds_read_window); + d = load_tile(d_lds_read_window); } if constexpr(is_main_body) Policy::template HotLoopScheduler::SchedulerGemm4(); @@ -771,25 +749,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR }; auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { - const bool is_even = (i_total_bodys % 2 == 0); - const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; - const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; - const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; - const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; - const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0; - const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1; - const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0; - const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1; + const bool is_even = (i_total_bodys % 2 == 0); + const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; + const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; main_body_impl(is_prologue_, is_epilogue_, q_lds_ptr_curr, q_lds_ptr_next, do_lds_ptr_curr, - do_lds_ptr_next, - lse_lds_ptr_curr, - lse_lds_ptr_next, - d_lds_ptr_curr, - d_lds_ptr_next); + do_lds_ptr_next); i_total_bodys += 1; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index d4a4e6a2ea..5adb64564d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR "BiasDataType and BiasGradDataType should be the same!"); // LSE: HBM -> LDS ->Reg - auto lse_dram_window = - make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), - lse_dram_block_window_tmp.get_window_lengths(), - {0}, - Policy::template MakeLSEDDramTileDistribution()); + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {0}, + Policy::template MakeLSEDDramTileDistribution()); auto lse_lds = make_tensor_view( lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); - auto lse_lds_read_window = - make_tile_window(lse_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); // D: HBM ->Reg - auto d_dram_window = - make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), - d_dram_block_window_tmp.get_window_lengths(), - {0}, - Policy::template MakeLSEDDramTileDistribution()); + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {0}, + Policy::template MakeLSEDDramTileDistribution()); auto d_lds = make_tensor_view( d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); - auto d_lds_read_window = - make_tile_window(d_lds, - make_tuple(number{}), - {0}, - Policy::template MakeLSEDLdsReadBlockDescriptor()); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); // RandVal: HBM ->Reg auto randval_dram_window = dropout.template MakeRandvalDramWindow( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 30c2c26416..6259e5b473 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -194,7 +194,13 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() { - return GetTransposedAlignmentX(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + return total_pixels / GetAlignmentOGrad(); } template @@ -352,30 +358,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy Problem::BlockFmhaShape::kVHeaddim>(); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - - constexpr index_t N0 = MWarp * NWarp; - - constexpr index_t M1 = kMPerBlock; - constexpr index_t M0 = get_warp_size() / M1; - static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0, - "M1 must be a factor of warp size"); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple>, - tuple, sequence<0, 1>>, - tuple, sequence<1, 0>>, - sequence<1>, - sequence<1>>{}); + return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution(); } template @@ -806,10 +793,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy return lsed_lds_block_desc; } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor() { - using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template at<1>(); @@ -998,16 +984,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE() { - return static_cast(max( // - sizeof(int) * get_warp_size(), - sizeof(typename Problem::LSEDataType) * - MakeLSEDLdsWriteBlockDescriptor().get_element_space_size())); + return sizeof(typename Problem::LSEDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); } template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD() { - return GetSmemSizeLSE(); + return sizeof(typename Problem::DDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); } template @@ -1054,9 +1039,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy constexpr index_t smem_size_bias = GetSmemSizeBias(); constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + - smem_size_lse * 2 + smem_size_d * 2 + - max(smem_size_bias, smem_size_ds); + constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse + + smem_size_d + max(smem_size_bias, smem_size_ds); return max(smem_size_stage0, smem_size_stage1); } @@ -1106,8 +1090,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy static constexpr index_t LSE_VMEM_READ = 1; static constexpr index_t D_VMEM_READ = 1; - static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add - // LDS Read static constexpr index_t OGradT_LDS_READ = kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); @@ -1134,12 +1116,11 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); static constexpr index_t OGradT_LDS_WRITE = kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad(); + static constexpr index_t LSE_LDS_WRITE = 1; + static constexpr index_t D_LDS_WRITE = 1; static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize; public: - static constexpr index_t TOTAL_VMEM_READ = - Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE; - CK_TILE_DEVICE static constexpr void SchedulerGemm0() { // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load @@ -1147,7 +1128,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy constexpr index_t VMEM_READ_INST = Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ; constexpr index_t MFMA_INST = Gemm0MFMA; - constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ; + constexpr index_t LDS_READ_INST = OGradT_LDS_READ; constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST); static_for<0, lcm_inst, 1>{}([&](auto i) { @@ -1180,8 +1161,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { // Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load. // Comp: SGradT x QT - constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE; - constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ; + constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE; + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ; constexpr index_t MFMA_INST = Gemm3MFMA; constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST; @@ -1204,7 +1185,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy { // Mem: SGrad, OGrad, D LDS load. // Comp: SGrad x KT - constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ; + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ; constexpr index_t MFMA_INST = Gemm4MFMA; constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);