This commit is contained in:
Ding, Yi
2026-04-23 03:46:52 -05:00
parent 5c9134f72b
commit f122fc731f
3 changed files with 22 additions and 34 deletions

View File

@@ -187,11 +187,10 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag
}}
template <>
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* device_ws, const void* host_ws, size_t device_ws_size, size_t host_ws_size)
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
k_::PrepareWorkspaceDevice(device_ws, host_ws, device_ws_size, host_ws_size);
return k_::NeedsZeroDqAcc();
}}
template <>

View File

@@ -479,10 +479,7 @@ size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
const ck_tile::index_t* seqstart_qs,
const ck_tile::index_t* seqstart_ks);
template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_(void* device_ws,
const void* host_ws,
size_t device_ws_size,
size_t host_ws_size);
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
@@ -629,10 +626,15 @@ struct fmha_bwd_launcher
traits.seqstart_qs,
traits.seqstart_ks);
}
workspace_size = host_ws_size + device_ws_size;
prepare_workspace = [this](void* device_ws) {
fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<T1, Arch>(
device_ws, ws_host.get(), device_ws_size, host_ws_size);
workspace_size = host_ws_size + device_ws_size;
const bool needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<T1, Arch>();
prepare_workspace = [this, needs_zero_dq_acc](void* device_ws) {
if(host_ws_size > 0)
HIP_CHECK_ERROR(
hipMemcpy(device_ws, ws_host.get(), host_ws_size, hipMemcpyHostToDevice));
if(needs_zero_dq_acc)
HIP_CHECK_ERROR(
hipMemset(static_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
};
}

View File

@@ -156,26 +156,15 @@ struct FmhaBwdWorkspaceManager
}
template <bool kUseQrQtrDorPipeline, bool kHasMask>
CK_TILE_HOST static void PrepareWorkspaceDevice(void* device_ws,
const void* host_ws,
size_t device_ws_size,
size_t host_ws_size)
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
constexpr bool NeedsZeroDqAcc = []() {
constexpr bool kUsePersistent =
!kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
if constexpr(kUsePersistent || !kIsDeterministic)
return true;
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases we need to zero out it first
return kHasMask;
}();
if(host_ws_size > 0)
HIP_CHECK_ERROR(hipMemcpy(device_ws, host_ws, host_ws_size, hipMemcpyHostToDevice));
if(NeedsZeroDqAcc)
HIP_CHECK_ERROR(
hipMemset(reinterpret_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
if constexpr(kUsePersistent || !kIsDeterministic)
return true;
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases we need to zero out it first
return kHasMask;
}
};
@@ -292,11 +281,9 @@ struct FmhaBwdDQDKDVKernel
FmhaPipeline::BlockFmhaShape::kN0>(
std::forward<Args>(args)...);
}
template <typename... Args>
CK_TILE_HOST static constexpr void PrepareWorkspaceDevice(Args&&... args)
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
WorkspaceManager::template PrepareWorkspaceDevice<kUseQrQtrDorPipeline, kHasMask>(
std::forward<Args>(args)...);
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
}
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template