From 4eeb5cc917cb62a517ab80f7ff749c958ed66812 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 27 Oct 2025 10:34:45 +0000 Subject: [PATCH] Update to gemm_0's CBlockDistribution encoding so that it is compatible with gemm_1's ABlockDistribution encoding --- .../block_gemm_areg_bsmem_creg_v2_hack_0.hpp | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp index ee1f216740..731810364c 100644 --- a/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp +++ b/example/ck_tile/18_hstu_attention/block_gemm_areg_bsmem_creg_v2_hack_0.hpp @@ -59,13 +59,15 @@ struct BlockGemmARegBSmemCRegV2Hack_0 const index_t iNWarp = get_warp_id() % NWarp; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); @@ -210,7 +212,7 @@ struct BlockGemmARegBSmemCRegV2Hack_0 } template - CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -233,11 +235,19 @@ struct BlockGemmARegBSmemCRegV2Hack_0 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); + return make_static_tile_distribution(a_block_dstr_encode); } template - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -248,18 +258,28 @@ struct BlockGemmARegBSmemCRegV2Hack_0 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); return c_block_tensor;