mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Add template parameter to gemm_0 MakeCBlockTile() for the need of defining PcompBlockTileType
This commit is contained in:
@@ -236,11 +236,9 @@ struct BlockGemmARegBSmemCRegV2Hack_0
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
@@ -166,10 +166,9 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
// SaccBlockTile size is [kM0, kK1]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
// using PcompBlockTileType = decltype(cast_tile<CompDataType>(SaccBlockTileType{}));
|
||||
using PcompBlockTileType = decltype(make_static_distributed_tensor<CompDataType>(
|
||||
Policy::template MakePRegTileDistribution<Problem>()));
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
|
||||
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
|
||||
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
|
||||
|
||||
SaccBlockTileType sacc_tile;
|
||||
PcompBlockTileType pcomp_tile;
|
||||
@@ -393,17 +392,10 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
|
||||
auto sacc_tile_tmp = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
using pcomp_tile_tmp_type =
|
||||
decltype(get_slice_tile(pcomp_tile, sequence<0, 0>{}, sequence<kM0, kK1>{}));
|
||||
|
||||
pcomp_tile_tmp_type pcomp_tile_tmp;
|
||||
|
||||
pcomp_tile_tmp.get_thread_buffer() = sacc_tile_tmp.get_thread_buffer();
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
set_slice_tile(pcomp_tile,
|
||||
pcomp_tile_tmp,
|
||||
tmp_tile,
|
||||
sequence<0, i_k1 * kK1>{},
|
||||
sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
});
|
||||
|
||||
@@ -69,7 +69,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
|
||||
@@ -78,12 +78,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
Problem::HstuAttentionTileSetting::kN0>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
return MakePRegTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user