Override and fix GetAlignmentK()

This commit is contained in:
Qianfeng Zhang
2025-05-03 16:17:28 +00:00
parent da89540ee0
commit 611f2ce1f9

View File

@@ -51,6 +51,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return 4;
}
// Returns the alignment, counted in KDataType elements, used for K:
// the smaller of
//   * 16 / sizeof(KDataType), and
//   * (kK1 * kQKHeaddim) / kBlockSize, i.e. the tile's element count divided
//     evenly (integer division) across the block.
// NOTE(review): the constant 16 is presumably a 16-byte maximum access width;
// confirm against the target's load instructions.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
    using ElemT = remove_cvref_t<typename Problem::KDataType>;

    // Upper bound on the vector length, derived from the element size in bytes.
    constexpr index_t VectorCap = 16 / sizeof(ElemT);

    // Tile extents and block size as published by the Problem description.
    constexpr index_t TileRows     = Problem::BlockFmhaShape::kK1;
    constexpr index_t TileCols     = Problem::BlockFmhaShape::kQKHeaddim;
    constexpr index_t ThreadsInBlk = Problem::kBlockSize;

    // Per-thread share of the tile; integer division truncates.
    constexpr index_t PerThreadElems = (TileRows * TileCols) / ThreadsInBlk;

    return min(VectorCap, PerThreadElems);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{