mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Remove some rarely needed interfaces from the pipeline problem
This commit is contained in:
@@ -38,7 +38,7 @@ CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
static_assert(false, "The data type is not supported!");
|
||||
return 1;
|
||||
};
|
||||
|
||||
template <typename DataType,
|
||||
@@ -108,33 +108,6 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
// Returns the result of detail::GetDramTileAccessMaxVectorSize instantiated for the Q tensor:
// QDataType, kBlockSize, and a tile of BlockFmhaShape::kM0 x BlockFmhaShape::kQKHeaddim.
// kBlockSize and QDataType come from the enclosing scope, which is not shown in this hunk.
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kMPerBlock = BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
// Returns the result of detail::GetDramTileAccessMaxVectorSize instantiated for the K tensor:
// KDataType, kBlockSize, and a tile of BlockFmhaShape::kN0 x BlockFmhaShape::kK0.
// kBlockSize and KDataType come from the enclosing scope, which is not shown in this hunk.
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = BlockFmhaShape::kK0;
|
||||
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
// Returns the result of detail::GetDramTileAccessMaxVectorSize instantiated for the V tensor:
// VDataType, kBlockSize, and a tile of BlockFmhaShape::kN1 x BlockFmhaShape::kK1.
// kBlockSize and VDataType come from the enclosing scope, which is not shown in this hunk.
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = BlockFmhaShape::kK1;
|
||||
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
// NOTE(review): the ';' after the closing brace below is a redundant empty declaration.
};
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
|
||||
@@ -153,18 +153,18 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
// special consideration when shuffling is required before storing V to LDS
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize();
|
||||
constexpr index_t kMaxVecLoad = detail::
|
||||
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
// try to avoid writing sub-dword to LDS due to poor performance
|
||||
@@ -176,7 +176,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
return Problem::GetVDramTileAccessMaxVectorSize();
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -615,8 +616,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::QDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::SaccDataType, float>)
|
||||
std::is_same_v<typename Problem::QDataType, bf16_t>) &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
@@ -681,8 +682,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::VDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::VDataType, bf16_t>)&&std::
|
||||
is_same_v<typename Problem::OaccDataType, float>)
|
||||
std::is_same_v<typename Problem::VDataType, bf16_t>) &&
|
||||
std::is_same_v<typename Problem::OaccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
Reference in New Issue
Block a user