diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp index d765a29a89..849d244463 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_policy.hpp @@ -46,9 +46,9 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy }; template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPVTWarpGemmKPerThreadSize() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -108,7 +108,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() { - if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + if constexpr(GetPVTWarpGemmKPerThreadSize() >= 8) return 8; else return 4; @@ -181,7 +181,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy { constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + constexpr index_t kKPack = GetPVTWarpGemmKPerThreadSize(); return N0 * (N1 * kKPerBlock + kKPack); } @@ -632,14 +632,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN() + CK_TILE_HOST_DEVICE static constexpr auto GetPVTBlockGemmSingleRepN() { return Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}) * Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{}); }; template - CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() + CK_TILE_HOST_DEVICE static constexpr auto GetPVTBlockGemm() { using GemmProblem = BlockGemmProblem< typename Problem::QKVDataType, @@ -726,7 +726,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index d928cef99e..488e3bda71 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -69,7 +69,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = - Policy::template GetKVBlockGemmSingleRepN(); + Policy::template GetPVTBlockGemmSingleRepN(); static constexpr index_t kBlockPerCu = []() { if constexpr(Traits::kBlockPerCu != -1) @@ -164,7 +164,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 9a6949d382..60c367f7ef 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -70,7 +70,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = - Policy::template GetKVBlockGemmSingleRepN(); + Policy::template GetPVTBlockGemmSingleRepN(); static constexpr index_t kBlockPerCu = []() { if constexpr(Traits::kBlockPerCu != -1) @@ -165,7 +165,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 5df4d905ce..54a456a03d 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = - Policy::template GetKVBlockGemmSingleRepN(); + Policy::template GetPVTBlockGemmSingleRepN(); static constexpr index_t kBlockPerCu = []() { if constexpr(Traits::kBlockPerCu != -1) @@ -175,7 +175,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVTBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0] diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index a69c65ceee..7f3902bfad 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -73,7 +73,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // used by NRepetitions2DEpilogue static constexpr index_t kGemm1SingleRepN = - Policy::template GetKVBlockGemmSingleRepN(); + Policy::template GetPVTBlockGemmSingleRepN(); static constexpr index_t kBlockPerCu = []() { if constexpr(Traits::kBlockPerCu != -1) @@ -176,7 +176,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = - Policy::template GetKVBlockGemm(); + Policy::template GetPVTBlockGemm(); // SaccBlockTile size is [kM0, kN0Sub] // PcompBlockTile size is [kM0, kN0]