Remove some not very much required interfaces from pipeline problem

This commit is contained in:
Qianfeng Zhang
2026-02-24 08:15:39 +00:00
parent 4d83c92fc4
commit b78a240d84
2 changed files with 14 additions and 40 deletions

View File

@@ -38,7 +38,7 @@ CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
return 1;
}
else
static_assert(false, "The data type is not supported!");
return 1;
};
template <typename DataType,
@@ -108,33 +108,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
{
constexpr index_t kMPerBlock = BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = BlockFmhaShape::kQKHeaddim;
return detail::
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
}
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
{
constexpr index_t kNPerBlock = BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = BlockFmhaShape::kK0;
return detail::
GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
{
constexpr index_t kNPerBlock = BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = BlockFmhaShape::kK1;
return detail::
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
};
};
template <typename QDataType_,

View File

@@ -153,18 +153,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// special consideration when shuffling is required before storing V to LDS
if constexpr(!Problem::kUseTrLoad)
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize();
constexpr index_t kMaxVecLoad = detail::
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
// try to avoid writing sub-dword to LDS due to poor performance
@@ -176,7 +176,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
}
else
{
return Problem::GetVDramTileAccessMaxVectorSize();
return detail::
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
};
}
@@ -615,8 +616,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::QDataType, half_t> ||
std::is_same_v<typename Problem::QDataType, bf16_t>)&&std::
is_same_v<typename Problem::SaccDataType, float>)
std::is_same_v<typename Problem::QDataType, bf16_t>) &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
@@ -681,8 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::VDataType, half_t> ||
std::is_same_v<typename Problem::VDataType, bf16_t>)&&std::
is_same_v<typename Problem::OaccDataType, float>)
std::is_same_v<typename Problem::VDataType, bf16_t>) &&
std::is_same_v<typename Problem::OaccDataType, float>)
{
constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});