diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 3c33caf338..57b6b7d3c2 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -51,7 +51,7 @@ struct FmhaBwdWorkspaceManager static constexpr size_t ALIGNMENT = 16; template - CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch) + CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch) { if constexpr(kUseQrQtrDorPipeline) return 0; @@ -59,28 +59,31 @@ struct FmhaBwdWorkspaceManager (kIsGroupMode && kIsDeterministic) ? static_cast(batch) : 1; return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT); } - CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch) + CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch) { const auto dqAccOffsetsElems = (kIsGroupMode && kIsDeterministic) ? static_cast(batch) : 0; return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT); } template - CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch) + CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch) { if constexpr(kUseQrQtrDorPipeline) return 0; - return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch); + const size_t raw = + GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch); + // Pad to 4K so dq_acc buffer always starts on a page-aligned boundary. + return integer_least_multiple(raw, static_cast(4096)); } - CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; } + CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; } template - CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch) + CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch) { return GetDqAccSplitsSize(batch); } template - CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch) + CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch) { return GetWorkspaceHostSize(batch); }