This commit is contained in:
Ding, Yi
2026-04-23 03:46:52 -05:00
parent 5c9134f72b
commit f122fc731f
3 changed files with 22 additions and 34 deletions

View File

@@ -156,26 +156,15 @@ struct FmhaBwdWorkspaceManager
}
template <bool kUseQrQtrDorPipeline, bool kHasMask>
CK_TILE_HOST static void PrepareWorkspaceDevice(void* device_ws,
const void* host_ws,
size_t device_ws_size,
size_t host_ws_size)
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
constexpr bool NeedsZeroDqAcc = []() {
constexpr bool kUsePersistent =
!kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
if constexpr(kUsePersistent || !kIsDeterministic)
return true;
// Some blocks may be skipped with a causal mask, so their dq is not set to zero.
// In these cases we need to zero it out first.
return kHasMask;
}();
if(host_ws_size > 0)
HIP_CHECK_ERROR(hipMemcpy(device_ws, host_ws, host_ws_size, hipMemcpyHostToDevice));
if(NeedsZeroDqAcc)
HIP_CHECK_ERROR(
hipMemset(reinterpret_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
if constexpr(kUsePersistent || !kIsDeterministic)
return true;
// Some blocks may be skipped with a causal mask, so their dq is not set to zero.
// In these cases we need to zero it out first.
return kHasMask;
}
};
@@ -292,11 +281,9 @@ struct FmhaBwdDQDKDVKernel
FmhaPipeline::BlockFmhaShape::kN0>(
std::forward<Args>(args)...);
}
template <typename... Args>
CK_TILE_HOST static constexpr void PrepareWorkspaceDevice(Args&&... args)
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
WorkspaceManager::template PrepareWorkspaceDevice<kUseQrQtrDorPipeline, kHasMask>(
std::forward<Args>(args)...);
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
}
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template