Move GetSmemKPackV() and GetAlignmentV() out of the ck_tile fmha policy into the hstu pipeline default policy for better visibility

This commit is contained in:
Qianfeng Zhang
2025-06-07 12:46:40 +00:00
parent b2db644dcd
commit 84eb9adc71

View File

@@ -101,6 +101,54 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return min(MaxVectorSize, ElemPerThread);
}
// Compile-time pack size used for V: the smaller of
//   (a) the number of V elements each thread owns in one kN1 x kK1 tile, and
//   (b) the number of VDataType elements that fit in 16 bytes.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
    using VElem = remove_cvref_t<typename Problem::VDataType>;

    // Tile extents and the thread count they are divided across.
    constexpr index_t num_threads = Problem::kBlockSize;
    constexpr index_t tile_n      = Problem::BlockFmhaShape::kN1;
    constexpr index_t tile_k      = Problem::BlockFmhaShape::kK1;

    // Elements of the tile handled by a single thread (integer division).
    constexpr index_t per_thread = tile_n * tile_k / num_threads;

    // Element count of a 16-byte vector for this data type.
    constexpr index_t vector_cap = 16 / sizeof(VElem);

    return min(per_thread, vector_cap);
}
// Compile-time vector width (in elements) used for V.
//
// The baseline width is min(elements per thread, elements in 16 bytes).
//  - When VLayout is not RowMajor, the baseline is returned unchanged.
//  - When VLayout is RowMajor, the baseline is kept only if each thread issues
//    at least `min_vec` such vectors (min_vec = elements in 4 bytes);
//    otherwise the width falls back to per_thread / min_vec.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
    using VElem         = remove_cvref_t<typename Problem::VDataType>;
    using VTensorLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;

    // Quantities shared by both layouts.
    constexpr index_t num_threads = Problem::kBlockSize;
    constexpr index_t tile_n      = Problem::BlockFmhaShape::kN1;
    constexpr index_t tile_k      = Problem::BlockFmhaShape::kK1;
    constexpr index_t per_thread  = tile_n * tile_k / num_threads;
    constexpr index_t vector_cap  = 16 / sizeof(VElem);
    constexpr index_t widest      = min(per_thread, vector_cap);

    constexpr bool is_row_major =
        std::is_same_v<VTensorLayout, ck_tile::tensor_layout::gemm::RowMajor>;

    if constexpr(!is_row_major)
    {
        return widest;
    }
    else
    {
        // Element count of a 4-byte load for this data type.
        constexpr index_t min_vec = 4 / sizeof(VElem);

        // Too few vectors per thread at the widest width: shrink the width.
        constexpr index_t chosen =
            ((per_thread / widest) < min_vec) ? (per_thread / min_vec) : widest;
        return chosen;
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{