Add template parameters to gemm_0 MakeCBlockTile() so that PcompBlockTileType can be defined

This commit is contained in:
Qianfeng Zhang
2025-10-16 15:41:26 +00:00
parent 1a8f2f21fb
commit a874839dc2
3 changed files with 7 additions and 23 deletions

View File

@@ -236,11 +236,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0
return make_static_tile_distribution(a_block_dstr_encode);
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;

View File

@@ -166,10 +166,9 @@ struct HstuAttentionFwdPipelineQRKSVS
// SaccBlockTile size is [kM0, kK1]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
// using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
using PcompBlockTileType = decltype(make_static_distributed_tensor<CompDataType>(
Policy::template MakePRegTileDistribution<Problem>()));
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
@@ -393,17 +392,10 @@ struct HstuAttentionFwdPipelineQRKSVS
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto sacc_tile_tmp = cast_tile<CompDataType>(sacc_tile);
using pcomp_tile_tmp_type =
decltype(get_slice_tile(pcomp_tile, sequence<0, 0>{}, sequence<kM0, kK1>{}));
pcomp_tile_tmp_type pcomp_tile_tmp;
pcomp_tile_tmp.get_thread_buffer() = sacc_tile_tmp.get_thread_buffer();
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
pcomp_tile_tmp,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});

View File

@@ -69,7 +69,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
@@ -78,12 +78,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
Problem::HstuAttentionTileSetting::kN0>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
return MakePRegTileDistribution<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{