mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
fix
This commit is contained in:
@@ -156,26 +156,15 @@ struct FmhaBwdWorkspaceManager
|
||||
}
|
||||
|
||||
template <bool kUseQrQtrDorPipeline, bool kHasMask>
|
||||
CK_TILE_HOST static void PrepareWorkspaceDevice(void* device_ws,
|
||||
const void* host_ws,
|
||||
size_t device_ws_size,
|
||||
size_t host_ws_size)
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
constexpr bool NeedsZeroDqAcc = []() {
|
||||
constexpr bool kUsePersistent =
|
||||
!kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
|
||||
// non-deterministic and persistent kernels use atomic-add to write dq
|
||||
if constexpr(kUsePersistent || !kIsDeterministic)
|
||||
return true;
|
||||
// Some blocks may be skipped with a causal mask, so their dq is not set to zero
|
||||
// In these cases we need to zero it out first
|
||||
return kHasMask;
|
||||
}();
|
||||
if(host_ws_size > 0)
|
||||
HIP_CHECK_ERROR(hipMemcpy(device_ws, host_ws, host_ws_size, hipMemcpyHostToDevice));
|
||||
if(NeedsZeroDqAcc)
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(reinterpret_cast<char*>(device_ws) + host_ws_size, 0, device_ws_size));
|
||||
constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
|
||||
// non-deterministic and persistent kernels use atomic-add to write dq
|
||||
if constexpr(kUsePersistent || !kIsDeterministic)
|
||||
return true;
|
||||
// Some blocks may be skipped with a causal mask, so their dq is not set to zero
|
||||
// In these cases we need to zero it out first
|
||||
return kHasMask;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -292,11 +281,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
FmhaPipeline::BlockFmhaShape::kN0>(
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST static constexpr void PrepareWorkspaceDevice(Args&&... args)
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
WorkspaceManager::template PrepareWorkspaceDevice<kUseQrQtrDorPipeline, kHasMask>(
|
||||
std::forward<Args>(args)...);
|
||||
return WorkspaceManager::template NeedsZeroDqAcc<kUseQrQtrDorPipeline, kHasMask>();
|
||||
}
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
|
||||
|
||||
Reference in New Issue
Block a user