mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
Rename GetKVBlockGemm to GetPVTBlockGemm
This commit is contained in:
@@ -46,9 +46,9 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPVTWarpGemmKPerThreadSize()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVTBlockGemm<Problem, kUseTrLoad>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
@@ -108,7 +108,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem, kUseTrLoad>() >= 8)
|
||||
if constexpr(GetPVTWarpGemmKPerThreadSize<Problem, kUseTrLoad>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
@@ -181,7 +181,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
constexpr index_t kKPack = GetPVTWarpGemmKPerThreadSize<Problem>();
|
||||
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
}
|
||||
@@ -632,14 +632,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPVTBlockGemmSingleRepN()
|
||||
{
|
||||
return Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}) *
|
||||
Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{});
|
||||
};
|
||||
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPVTBlockGemm()
|
||||
{
|
||||
using GemmProblem = BlockGemmProblem<
|
||||
typename Problem::QKVDataType,
|
||||
@@ -726,7 +726,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVTBlockGemm<Problem, kUseTrLoad>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
Policy::template GetPVTBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
@@ -164,7 +164,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
|
||||
@@ -70,7 +70,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
Policy::template GetPVTBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
@@ -165,7 +165,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem, true /*kUseTrLoad*/>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
|
||||
@@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
Policy::template GetPVTBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
@@ -175,7 +175,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
|
||||
@@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// used by NRepetitions2DEpilogue
|
||||
static constexpr index_t kGemm1SingleRepN =
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
Policy::template GetPVTBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
@@ -176,7 +176,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 =
|
||||
Policy::template GetKVBlockGemm<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
Policy::template GetPVTBlockGemm<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
|
||||
Reference in New Issue
Block a user