From 0033748c627827d86542d8623f757f7bfed05237 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Wed, 28 Jan 2026 10:37:13 -0500 Subject: [PATCH] revert custom ldstile, should be able to use the regular ones --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 10 +- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 142 ------------------ 2 files changed, 7 insertions(+), 145 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index cdb00679e5..30ae9d9058 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -437,11 +437,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step); // tile distribution for the register tiles - // Use custom distributions that account for packed types constexpr auto ALdsTileDistr = - make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode()); + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); constexpr auto BLdsTileDistr = - make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode()); + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); @@ -450,6 +449,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< ALdsTile a_block_tile0, a_block_tile1; BLdsTile b_block_tile0, b_block_tile1; + static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!"); + static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!"); + static_assert(Policy::template GetSmemSizeA() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!"); + static_assert(Policy::template GetSmemSizeB() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!"); + ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// // Calculate scale iterations: each scale covers 32 elements in K // Each K iteration processes KPerXdl elements diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 01c4b4b9cc..7d5feecb8f 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -244,148 +244,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return BlockGemmARegBRegCRegV1{}; } - // Custom warp distribution encodings that account for packed types - // For 16x16x128 MFMA with pk_fp4_t, the K dimension must use storage elements, not logical elements - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding() - { - // For 16x16x128 MFMA with pk_fp4_t (PackedSize=2) - // Physical layout in registers: [16 M-lanes, 4 K-lanes, 16 bytes per lane] - // Each byte stores 2 fp4 values, so 16 bytes = 32 fp4 values - // WarpGemm expects LOGICAL dimensions, so use 32 (logical fp4), not 16 (storage) - constexpr index_t kAMLane = 16; - constexpr index_t kABKLane = 4; - constexpr index_t kABKPerLane = 32; // LOGICAL elements (not divided by PackedSize)! - // have also tried 16 here - - return tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding() - { - constexpr index_t kBNLane = 16; - constexpr index_t kABKLane = 4; - constexpr index_t kABKPerLane = 32; // LOGICAL elements (not divided by PackedSize)! have also tried 16 - - return tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - - // Custom LDS block distributions that account for packed types - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t MPerXdl = WarpTile::at(number<0>{}); - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - // Use LOGICAL dimensions for iteration count (matches WarpGemm expectations) - // LDS shape is [MPerBlock, KPerBlock / APackedSize] in storage (bytes) - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; // Logical K iterations - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); - - constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); - - if constexpr(UseDefaultScheduler) - { - // here the iters don't get affected by PackedSize - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - // Use custom warp encoding that accounts for packed types - return detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding()); - } - else - { - constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - // Use custom warp encoding that accounts for packed types - return detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding()); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode() - { - using BlockGemmShape = typename Problem::BlockGemmShape; - using BlockWarps = typename BlockGemmShape::BlockWarps; - using WarpTile = typename BlockGemmShape::WarpTile; - - constexpr index_t MWarp = BlockWarps::at(number<0>{}); - constexpr index_t NWarp = BlockWarps::at(number<1>{}); - constexpr index_t NPerXdl = WarpTile::at(number<1>{}); - constexpr index_t KPerXdl = WarpTile::at(number<2>{}); - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - // Use LOGICAL dimensions for iteration count (matches WarpGemm expectations) - // LDS shape is [NPerBlock, KPerBlock / BPackedSize] in storage - // But distributions work in logical space - constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; // Logical K iterations - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - - constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); - - if constexpr(UseDefaultScheduler) - { - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple<>, - tuple<>, - sequence<1, 2>, - sequence<0, 0>>{}; - // Use custom warp encoding that accounts for packed types - return detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding()); - } - else - { - constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - // Use custom warp encoding that accounts for packed types - return detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding()); - } - } - // MX Scale tile distributions for loading from global memory // Using the proven "Flat" patterns from v1 policy template