mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Update in the implementation of GetAlignmentQ/GetAlignmentK/GetAlignmentV
This commit is contained in:
@@ -116,18 +116,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
|
||||
{
|
||||
return Problem::GetQDramTileAccessMaxVectorSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::HstuAttentionTileSetting::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
return Problem::template GetDramTileAccessMaxVectorSize<QDataType,
|
||||
kBlockSize,
|
||||
kMPerBlock,
|
||||
kKPerBlock>();
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -142,16 +147,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
return Problem::GetKDramTileAccessMaxVectorSize();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -166,21 +162,22 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
|
||||
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// special consideration when shuffling is required before storing V to LDS
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
// try to avoid writing sub-dword to LDS due to poor performance
|
||||
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (ElemPerThread / kMinVecLoad);
|
||||
@@ -189,10 +186,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
return Problem::GetVDramTileAccessMaxVectorSize();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,70 @@ struct HstuAttentionFwdPipelineProblem
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();
|
||||
|
||||
template <typename DataType, index_t ElemPerThread>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
|
||||
{
|
||||
if constexpr(ElemPerThread % 8 == 0)
|
||||
return 8;
|
||||
else if constexpr(ElemPerThread % 6 == 0)
|
||||
return 6;
|
||||
else if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 3 == 0)
|
||||
return 3;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
static_assert(false, "The data type is not supported!");
|
||||
};
|
||||
|
||||
template <typename DataType,
|
||||
index_t kThreadBlockSize,
|
||||
index_t kHigherDimSize,
|
||||
index_t kLowerDimSize>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
|
||||
|
||||
return GetMaxVectorSize<DataType, ElemPerThread>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kMPerBlock = HstuAttentionTileSetting::kM0;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kMPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN0Sub;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = HstuAttentionTileSetting::kK1;
|
||||
|
||||
return GetDramTileAccessMaxVectorSize<QKVDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user