mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
FMHA BWD workspace: 4K-align dq_acc base
Mark the FmhaBwdWorkspaceManager size/offset accessors as CK_TILE_HOST (they are only invoked from host-side workspace setup), and pad GetWorkspaceHostSize up to a 4K boundary so the GPU dq_acc buffer always starts on a page-aligned offset.
This commit is contained in:
@@ -51,7 +51,7 @@ struct FmhaBwdWorkspaceManager
|
||||
static constexpr size_t ALIGNMENT = 16;
|
||||
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
@@ -59,28 +59,31 @@ struct FmhaBwdWorkspaceManager
|
||||
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 1;
|
||||
return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch)
|
||||
{
|
||||
const auto dqAccOffsetsElems =
|
||||
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 0;
|
||||
return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT);
|
||||
}
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
|
||||
const size_t raw =
|
||||
GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
|
||||
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
|
||||
return integer_least_multiple(raw, static_cast<size_t>(4096));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; }
|
||||
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch)
|
||||
{
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch);
|
||||
}
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
|
||||
{
|
||||
return GetWorkspaceHostSize<kUseQrQtrDorPipeline>(batch);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user