Update to gemm_0's CBlockDistribution encoding so that it is compatible with gemm_1's ABlockDistribution encoding

This commit is contained in:
Qianfeng Zhang
2025-10-27 10:34:45 +00:00
parent 98a241a2eb
commit 4eeb5cc917

View File

@@ -59,13 +59,15 @@ struct BlockGemmARegBSmemCRegV2Hack_0
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
static_assert(NWarp == 1, "Check failed!");
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
@@ -210,7 +212,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
@@ -233,11 +235,19 @@ struct BlockGemmARegBSmemCRegV2Hack_0
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode<MPerBlock, KPerBlock>();
return make_static_tile_distribution(a_block_dstr_encode);
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
@@ -248,18 +258,28 @@ struct BlockGemmARegBSmemCRegV2Hack_0
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
static_assert(NWarp == 1, "Check failed!");
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
return c_block_dstr_encode;
}
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode<MPerBlock, NPerBlock>();
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;