From f6bfcad4379be760693354be5cff2c476885dd69 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 13 Mar 2026 14:13:32 +0800 Subject: [PATCH] [CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic Mode (#5174) ## Motivation This PR enables a persistent-kernel execution path for FMHA backward (dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is split, stored, and converted back to final gradients. ## Technical Details - Introduces a persistent-kernel grid mapping in deterministic mode and updates split-count calculation accordingly. - Extends kernel kargs to carry batch-related info needed for persistent scheduling and dQ conversion. - Refactors dQ store conditions and adds mask-type traits/utilities and runner logging updates. ## Test Plan - Jenkins [base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline) - Jenkins [AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline) - Jenkins [FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline) - local FA tests ## Test Result ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 19 +- example/ck_tile/01_fmha/fmha_bwd.hpp | 12 +- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 12 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 169 ++++++++++++++---- ...k_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 3 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 6 +- ...bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 3 +- 7 files changed, 175 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 02055ffd9e..6739abf621 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -169,10 +169,17 @@ int fmha_bwd_dq_dk_dv_maxq_() }} template <> -int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k) +int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t) {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::GetDqAccSplits(seqlen_k); + return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k); +}} + +template <> +bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::NeedsZeroDqAcc(); }} template <> @@ -192,6 +199,7 @@ fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{ {F_launcher} run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }}; dq_acc_splits = 1; + needs_zero_dq_acc = false; }} @@ -231,7 +239,8 @@ FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """ run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{ return fmha_bwd_, {F_arch.tag}>(s, a); }}; - dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_(t.max_seqlen_k); + dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_(t); + needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); return; }} """ @@ -447,7 +456,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load) if dtype in ["fp16", "bf16"] and tr_load == "t": 
results.extend([ - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 32, 256, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), @@ -823,7 +832,7 @@ class FmhaBwdApiTrait: @property def extra_cond(self) -> str: - if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: + if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128: return " && (t.seqlen_k <= 256)" else: return "" diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 983ac50231..8eb8834e12 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -251,6 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.seqlen_k_ptr, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, + args.batch, args.hdim_q, args.hdim_v, args.nhead_q, @@ -300,6 +301,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) dq_ptr, args.seqlen_q, args.seqlen_k, + args.batch, args.hdim_q, args.hdim_v, args.nhead_q, @@ -429,7 +431,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_dq_acc, args.batch_stride_dq, args.batch_stride_dq_acc, - args.split_stride_dq_acc); + args.split_stride_dq_acc, + args.batch, + args.nhead_q); } }(); @@ -465,8 +469,11 @@ template std::string fmha_bwd_dq_dk_dv_get_name_(); template int fmha_bwd_dq_dk_dv_maxq_(); +struct fmha_bwd_traits; template -int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k); +int 
fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t); +template +bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); template struct fmha_bwd_dot_do_o_traits_ @@ -569,6 +576,7 @@ struct fmha_bwd_launcher { std::function run{}; ck_tile::index_t dq_acc_splits{0}; + bool needs_zero_dq_acc{true}; fmha_bwd_launcher(const fmha_bwd_traits&); diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 92ae94d9b1..3123e4f2a8 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -416,9 +416,10 @@ bwd_result fmha_bwd_run(mode_enum mode, << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval << ", deterministic:" << deterministic - << (deterministic ? std::string(", workspace:") + - std::to_string(workspace_size_in_megabytes) + "MiB" - : "") + << (deterministic + ? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) + + "MiB|" + std::to_string(nsplits) + "splits" + : "") << ", mask:" << mask << std::flush; auto fmha_args = [&]() { @@ -842,10 +843,7 @@ bwd_result fmha_bwd_run(mode_enum mode, lse_buf.ToDevice(lse_host.data()); dbias_buf.SetZero(); - // non-deterministic kernels use atomic add to write dq - // Some block may be skipped with causal mask and dq are not set to zeros - // In these cases thus we need to zero out it first - if(!deterministic || mask.type != mask_enum::no_mask) + if(launcher.needs_zero_dq_acc) dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index ee9f87c525..d32d5a321d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -79,6 +79,8 @@ struct FmhaBwdDQDKDVKernel #else static 
constexpr bool kIsAvailable = !kUseTrLoad; #endif + static constexpr bool kUsePersistent = + kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline; // clang-format off template struct t2s; @@ -124,13 +126,43 @@ struct FmhaBwdDQDKDVKernel #undef _TS_ // clang-format on } - CK_TILE_HOST static index_t GetDqAccSplits(index_t seqlen_k) + CK_TILE_HOST static index_t + GetDqAccSplits(index_t batch_size_, index_t nhead_, index_t seqlen_k_) { - if constexpr(kIsDeterministic) - return integer_divide_ceil(seqlen_k, FmhaPipeline::BlockFmhaShape::kN0); + // Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent + static constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode; + if constexpr(kUsePersistent__) + { + const index_t dqdqkdv_workers = get_num_cus(); + const index_t jobs_per_head = + integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0); + const index_t total_jobs = batch_size_ * nhead_ * jobs_per_head; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers); + if(jobs_per_head % jobs_per_worker == 0) + return jobs_per_head / jobs_per_worker; + else if(jobs_per_worker % jobs_per_head == 0) + return 1; + else + return 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker); + } + else if constexpr(kIsDeterministic) + return integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0); + else + return 1; + } + CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() + { + // Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent + constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode; + + // non-deterministic and persistent kernels use atomic-add to write dq + if constexpr(kUsePersistent__ || !kIsDeterministic) + return true; + + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases we need to zero it out first + return kHasMask; + } template // to avoid duplicated base class prblem, introduce an
template // arg @@ -282,6 +314,7 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdDeterministicKargs { ck_tile::index_t split_stride_dq_acc = 0; + ck_tile::index_t batch; // used for persistent kernel implementation }; struct FmhaBwdBatchModeKargs @@ -362,6 +395,7 @@ struct FmhaBwdDQDKDVKernel void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, + ck_tile::index_t batch, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -507,9 +541,10 @@ struct FmhaBwdDQDKDVKernel } if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline) - { kargs.split_stride_dq_acc = split_stride_dq_acc; - } + + if constexpr(kUsePersistent) + kargs.batch = batch; return kargs; } @@ -534,6 +569,7 @@ struct FmhaBwdDQDKDVKernel const void* seqlen_k_ptr, const void* cu_seqlen_q_ptr, const void* cu_seqlen_k_ptr, + ck_tile::index_t batch, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -659,9 +695,9 @@ struct FmhaBwdDQDKDVKernel } } if constexpr(kIsDeterministic) - { kargs.split_stride_dq_acc = split_stride_dq_acc; - } + if constexpr(kUsePersistent) + kargs.batch = batch; return kargs; } @@ -669,19 +705,12 @@ struct FmhaBwdDQDKDVKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { - return dim3( - kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), - nhead_, - batch_size_); - } - - CK_TILE_DEVICE static constexpr auto GetTileIndex() - { - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - return ck_tile::make_tuple(i_block, i_nhead, i_batch); + const index_t jobs_per_head = + kUseQrQtrDorPipeline ? 
1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0); + if constexpr(kUsePersistent) + return dim3(get_num_cus(), 1, 1); + else + return dim3(jobs_per_head, nhead_, batch_size_); } CK_TILE_HOST static dim3 BlockSize() @@ -706,16 +735,64 @@ struct FmhaBwdDQDKDVKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { if constexpr(kIsAvailable) - run_(std::move(kargs)); + { + if constexpr(!kUsePersistent) + { + run_(std::move(kargs), blockIdx, blockIdx.x); + } + else + { + static_assert(!kUseQrQtrDorPipeline, + "Persistent kernel is not compatible with QR/QTR/DOR pipeline"); + const index_t worker_id = blockIdx.x; + const index_t worker_num = gridDim.x; + + const index_t jobs_per_head = + integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const index_t total_heads = kargs.batch * kargs.num_head_q; + const index_t total_jobs = jobs_per_head * total_heads; + const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num); + + const index_t begin_job_id = worker_id * jobs_per_worker; + if(begin_job_id >= total_jobs) + return; // worker_id exceeds total jobs, exit early + const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs); + + // 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case + constexpr auto tile_n_interleave = [](index_t x, index_t n) { + if constexpr(kHasMask == false) + return x; + else + return x % 2 == 0 ? 
(x / 2) : (n - 1 - x / 2); + }; + + index_t job_id = begin_job_id; + index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker); + do + { // loop over jobs assigned to this worker + const index_t i_head_flatten = job_id / jobs_per_head; + const index_t i_tile_n_ = job_id % jobs_per_head; + const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head); + const index_t i_batch = i_head_flatten / kargs.num_head_q; + const index_t i_nhead = i_head_flatten % kargs.num_head_q; + + if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head + i_split = 0; + run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split); + } while(++job_id < end_job_id); + } + } } - CK_TILE_DEVICE void run_(Kargs kargs) const + CK_TILE_DEVICE void run_(Kargs kargs, const dim3& tile_index, const index_t i_split) const { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; // divide problem - const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex(); + const index_t i_tile_n = tile_index.x; + const index_t i_nhead = tile_index.y; + const index_t i_batch = tile_index.z; const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0); @@ -931,21 +1008,21 @@ struct FmhaBwdDQDKDVKernel make_tuple(number{}, number{}), {0, 0}); - auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { + auto dq_dram_window = [&, i_nhead_ = i_nhead]() { constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic; using DType = std::conditional_t; auto dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { if constexpr(kUseKSplit) return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - static_cast(i_tile_n_) * kargs.split_stride_dq_acc + + static_cast(i_split) * kargs.split_stride_dq_acc + batch_offset_dq_acc; else return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + batch_offset_dq_acc; }(); - constexpr auto DstInMemOp = conditional_expr( + constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>( 
memory_operation_enum::set, memory_operation_enum::atomic_add); const auto dq_acc_dram_naive = make_naive_tensor_view( @@ -1528,6 +1605,7 @@ struct FmhaBwdConvertQGradKernel static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ; static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic; + static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode; // clang-format off template struct t2s; @@ -1586,7 +1664,10 @@ struct FmhaBwdConvertQGradKernel struct FmhaBwdConvertQGradDeterministicKargs { - ck_tile::index_t split_stride_dq_acc = 0; + index_t split_stride_dq_acc = 0; + index_t dqdqkdv_workers = 0; // 0 for not using persistent kernel + index_t batch_size = 0; // for nsplits calc of persistent kernel + index_t nhead = 0; // for nsplits calc of persistent kernel }; struct FmhaBwdConvertQGradBatchModeKargs @@ -1630,7 +1711,9 @@ struct FmhaBwdConvertQGradKernel ck_tile::long_index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::long_index_t batch_stride_dq_acc, - ck_tile::index_t split_stride_dq_acc) + ck_tile::index_t split_stride_dq_acc, + ck_tile::index_t batch_size, + ck_tile::index_t nhead) { Kargs kargs{{dq_acc_ptr, dq_ptr, @@ -1648,6 +1731,12 @@ struct FmhaBwdConvertQGradKernel if constexpr(kIsDeterministic) { kargs.split_stride_dq_acc = split_stride_dq_acc; + if constexpr(kUsePersistent) + { + kargs.dqdqkdv_workers = get_num_cus(); + kargs.batch_size = batch_size; + kargs.nhead = nhead; + } } return kargs; @@ -1783,6 +1872,27 @@ struct FmhaBwdConvertQGradKernel QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + static_cast(i_nhead) * kargs.nhead_stride_dq + batch_offset_dq; + const index_t nsplits = [&]() { + const index_t jobs_per_head = integer_divide_ceil(kargs.seqlen_k, kN0); + if constexpr(!kIsDeterministic) + return 1; + else if constexpr(!kUsePersistent) + return jobs_per_head; + else + { + const index_t 
total_heads = kargs.batch_size * kargs.nhead; + const index_t total_jobs = jobs_per_head * total_heads; + const index_t jobs_per_worker = + integer_divide_ceil(total_jobs, kargs.dqdqkdv_workers); + const index_t i_head_flatten = i_batch * kargs.nhead + i_nhead; + + const index_t i_job_start = jobs_per_head * i_head_flatten; + const index_t begin_worker_id = i_job_start / jobs_per_worker; + const index_t end_worker_id = // inclusive + (i_job_start + jobs_per_head - 1) / jobs_per_worker; + return end_worker_id - begin_worker_id + 1; + } + }(); // dQAcc/dQ DRAM and DRAM window const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() { @@ -1793,8 +1903,6 @@ struct FmhaBwdConvertQGradKernel static_cast(i_nhead_) * (kargs.nhead_stride_dq_acc) + batch_offset_dq_acc; - const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0); - auto dq_acc_dram_naive = make_naive_tensor_view( dq_acc_ptr, make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q), @@ -1856,7 +1964,6 @@ struct FmhaBwdConvertQGradKernel if constexpr(kIsDeterministic) { - const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0); FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits); } else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 7cc424597a..e4332df930 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -753,7 +753,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR { tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); } - if constexpr(kIsDeterministic) + if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp == + memory_operation_enum::set) { store_tile(dq_dram_window, dq_acc); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 95c9a7ad19..03ee1486da 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -789,7 +789,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP { tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); } - if constexpr(kIsDeterministic) + if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp == + memory_operation_enum::set) { store_tile(dq_dram_window, dq_acc); } @@ -1034,7 +1035,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); } - if constexpr(kIsDeterministic) + if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp == + memory_operation_enum::set) { store_tile(dq_dram_window, dq_acc); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 16212c0d13..7f893a93ba 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -764,7 +764,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR { tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); } - if constexpr(kIsDeterministic) + if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp == + memory_operation_enum::set) { store_tile(dq_dram_window, dq_acc); }