From f122fc731f85d76071a13147ee161ba34cd5c446 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Thu, 23 Apr 2026 03:46:52 -0500 Subject: [PATCH] fix --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 5 ++- example/ck_tile/01_fmha/fmha_bwd.hpp | 18 +++++----- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 33 ++++++------------- 3 files changed, 22 insertions(+), 34 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 813667df0f..f89a7d75e4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -187,11 +187,10 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_ -void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_( - void* device_ws, const void* host_ws, size_t device_ws_size, size_t host_ws_size) +bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_() {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - k_::PrepareWorkspaceDevice(device_ws, host_ws, device_ws_size, host_ws_size); + return k_::NeedsZeroDqAcc(); }} template <> diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index bab3d5ac15..14f4c210f0 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -479,10 +479,7 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws, const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks); template -void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_(void* device_ws, - const void* host_ws, - size_t device_ws_size, - size_t host_ws_size); +bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); template struct fmha_bwd_dot_do_o_traits_ @@ -629,10 +626,15 @@ struct fmha_bwd_launcher traits.seqstart_qs, traits.seqstart_ks); } - workspace_size = host_ws_size + device_ws_size; - prepare_workspace = [this](void* device_ws) { - fmha_bwd_dq_dk_dv_dq_prepare_ws_device_( - device_ws, ws_host.get(), device_ws_size, host_ws_size); + workspace_size = host_ws_size + device_ws_size; + const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_(); + prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) { + if(host_ws_size > 0) + HIP_CHECK_ERROR( + hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice)); + if(needs_zero_dq_acc) + HIP_CHECK_ERROR( + hipMemset(static_cast(device_ws) + host_ws_size, 0, device_ws_size)); }; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 57b6b7d3c2..fcbf6dfc2d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -156,26 +156,15 @@ struct FmhaBwdWorkspaceManager } template - CK_TILE_HOST static void PrepareWorkspaceDevice(void* device_ws, - const void* host_ws, - size_t device_ws_size, - size_t host_ws_size) + CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() { - constexpr bool NeedsZeroDqAcc = []() { - constexpr bool kUsePersistent = - !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode; - // non-deterministic and persistent kernels use atomic-add to write dq - if constexpr(kUsePersistent || !kIsDeterministic) - return true; - // Some block may be skipped with causal mask and dq are not set to zeros - // In these cases we need to zero out it first - return kHasMask; - }(); - if(host_ws_size > 0) - HIP_CHECK_ERROR(hipMemcpy(device_ws, host_ws, host_ws_size, hipMemcpyHostToDevice)); - if(NeedsZeroDqAcc) - HIP_CHECK_ERROR( - hipMemset(reinterpret_cast(device_ws) + host_ws_size, 0, device_ws_size)); + constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode; + // non-deterministic and persistent kernels use atomic-add to write dq + if constexpr(kUsePersistent || !kIsDeterministic) + return true; + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases we need to zero out it first + return kHasMask; } }; @@ -292,11 +281,9 @@ struct FmhaBwdDQDKDVKernel FmhaPipeline::BlockFmhaShape::kN0>( std::forward(args)...); } - template - CK_TILE_HOST static constexpr void PrepareWorkspaceDevice(Args&&... args) + CK_TILE_HOST static constexpr bool NeedsZeroDqAcc() { - WorkspaceManager::template PrepareWorkspaceDevice( - std::forward(args)...); + return WorkspaceManager::template NeedsZeroDqAcc(); } template // to avoid duplicated base class prblem, introduce an template