From a874839dc282cddaf63ba4bace62e666fdb3a8ef Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Thu, 16 Oct 2025 15:41:26 +0000 Subject: [PATCH] Add template parameter to gemm_0 MakeCBlockTile() for the need of defining PcompBlockTileType --- .../block_gemm_areg_bsmem_creg_v2_hack_0.hpp | 4 +--- .../hstu_attention_fwd_pipeline.hpp | 18 +++++------------- ...u_attention_fwd_pipeline_default_policy.hpp | 8 +------- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp index 7764178ea5..ee1f216740 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp @@ -236,11 +236,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0 return make_static_tile_distribution(a_block_dstr_encode); } + template CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr index_t MPerBlock = BlockGemmShape::kM; - constexpr index_t NPerBlock = BlockGemmShape::kN; - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index fc3a2e8bd0..caad20e33f 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -166,10 +166,9 @@ struct HstuAttentionFwdPipelineQRKSVS // SaccBlockTile size is [kM0, kK1] // PcompBlockTile size is [kM0, kN0] - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - // using PcompBlockTileType = decltype(cast_tile(SaccBlockTileType{})); - using PcompBlockTileType = decltype(make_static_distributed_tensor( - Policy::template MakePRegTileDistribution())); + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); SaccBlockTileType sacc_tile; PcompBlockTileType pcomp_tile; @@ -393,17 +392,10 @@ struct HstuAttentionFwdPipelineQRKSVS sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); - auto sacc_tile_tmp = cast_tile(sacc_tile); - - using pcomp_tile_tmp_type = - decltype(get_slice_tile(pcomp_tile, sequence<0, 0>{}, sequence{})); - - pcomp_tile_tmp_type pcomp_tile_tmp; - - pcomp_tile_tmp.get_thread_buffer() = sacc_tile_tmp.get_thread_buffer(); + auto tmp_tile = cast_tile(sacc_tile); set_slice_tile(pcomp_tile, - pcomp_tile_tmp, + tmp_tile, sequence<0, i_k1 * kK1>{}, sequence{}); }); diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 653c3a1da3..5694b948b5 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -69,7 +69,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy }; template - CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() { using BlockGemm = remove_cvref_t())>; @@ -78,12 +78,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy Problem::HstuAttentionTileSetting::kN0>(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() - { - return MakePRegTileDistribution(); - } - template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() {