Update in the implementation of GetAlignmentQ/GetAlignmentK/GetAlignmentV

This commit is contained in:
Qianfeng Zhang
2025-12-11 10:47:54 +00:00
parent 8640ffe8eb
commit db39b44bab
2 changed files with 89 additions and 31 deletions

View File

@@ -116,18 +116,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
{
return Problem::GetQDramTileAccessMaxVectorSize();
}
else
{
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::HstuAttentionTileSetting::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
return Problem::template GetDramTileAccessMaxVectorSize<QDataType,
kBlockSize,
kMPerBlock,
kKPerBlock>();
};
}
template <typename Problem>
@@ -142,16 +147,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
return Problem::GetKDramTileAccessMaxVectorSize();
}
template <typename Problem>
@@ -166,21 +162,22 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// special consideration when shuffling is required before storing V to LDS
if constexpr(!Problem::kUseTrLoad)
{
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
// try to avoid writing sub-dword to LDS due to poor performance
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
@@ -189,10 +186,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
}
else
{
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
return Problem::GetVDramTileAccessMaxVectorSize();
};
}

View File

@@ -54,6 +54,70 @@ struct HstuAttentionFwdPipelineProblem
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();
template <typename DataType, index_t ElemPerThread>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
{
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
{
if constexpr(ElemPerThread % 8 == 0)
return 8;
else if constexpr(ElemPerThread % 6 == 0)
return 6;
else if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else if constexpr(std::is_same_v<DataType, float>)
{
if constexpr(ElemPerThread % 4 == 0)
return 4;
else if constexpr(ElemPerThread % 3 == 0)
return 3;
else if constexpr(ElemPerThread % 2 == 0)
return 2;
return 1;
}
else
static_assert(false, "The data type is not supported!");
};
template <typename DataType,
index_t kThreadBlockSize,
index_t kHigherDimSize,
index_t kLowerDimSize>
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
{
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
return GetMaxVectorSize<DataType, ElemPerThread>();
}
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
{
constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kMPerBlock, kKPerBlock>();
}
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
{
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
{
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kK1;
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
};
};
} // namespace ck_tile