FMHA BWD workspace: 4K-align dq_acc base

Mark the FmhaBwdWorkspaceManager size/offset accessors as CK_TILE_HOST
(they are only invoked from host-side workspace setup), and pad
GetWorkspaceHostSize up to a 4K boundary so the GPU dq_acc buffer always
starts on a page-aligned offset.
This commit is contained in:
Ding, Yi
2026-04-22 02:08:01 -05:00
parent 4195052efa
commit 3607588ca4

View File

@@ -51,7 +51,7 @@ struct FmhaBwdWorkspaceManager
static constexpr size_t ALIGNMENT = 16;
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch)
CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch)
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
@@ -59,28 +59,31 @@ struct FmhaBwdWorkspaceManager
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 1;
return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT);
}
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch)
CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch)
{
const auto dqAccOffsetsElems =
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 0;
return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT);
}
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch)
CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch)
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
const size_t raw =
GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
return integer_least_multiple(raw, static_cast<size_t>(4096));
}
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; }
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch)
CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch)
{
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch);
}
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch)
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
{
return GetWorkspaceHostSize<kUseQrQtrDorPipeline>(batch);
}