From 70c7fcda43e163bf2be53a8185892dff6a1e677b Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 26 Jan 2026 11:33:45 -0500 Subject: [PATCH] WIP: debugging... --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 41 ++++- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 157 ++++++++++++++++++ 2 files changed, 190 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index cd486df521..aafcb70002 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -295,22 +295,46 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< ////////////// global window & register ///////////////// // A DRAM tile window(s) for load + auto a_tile_windows = generate_tuple( [&](auto idx) { + // Get bottom tensor view and window origin: need to divide by APackedSize + auto&& bottom_tensor_view = a_dram_block_window_tmp[number{}].get_bottom_tensor_view(); + auto&& tensor_ptr = reinterpret_cast(&(bottom_tensor_view.get_buffer_view()(0))); + auto&& tensor_view = make_naive_tensor_view( + tensor_ptr, + make_tuple(4096, 4096 / APackedSize), + make_tuple(4096 / APackedSize, 1), + number<32>{}, + number<1>{}); + const auto& origin = a_dram_block_window_tmp[number{}].get_window_origin(); return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + // a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + tensor_view, make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), + {origin[0], origin[1] / APackedSize}, Policy::template MakeADramTileDistribution()); }, number{}); // B DRAM window(s) for load auto b_tile_windows = generate_tuple( [&](auto idx) { + // Get bottom tensor view and window origin: need to divide by BPackedSize + auto&& bottom_tensor_view = b_dram_block_window_tmp[number{}].get_bottom_tensor_view(); + auto&& tensor_ptr = reinterpret_cast(&(bottom_tensor_view.get_buffer_view()(0))); + auto&& tensor_view = make_naive_tensor_view( + tensor_ptr, + make_tuple(4096, 4096 / BPackedSize), + make_tuple(4096 / BPackedSize, 1), + number<32>{}, + number<1>{}); + const auto& origin = b_dram_block_window_tmp[number{}].get_window_origin(); return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + // b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + tensor_view, make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), + // b_dram_block_window_tmp[number{}].get_window_origin(), + {origin[0], origin[1] / BPackedSize}, Policy::template MakeBDramTileDistribution()); }, number{}); @@ -397,9 +421,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize); constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize); // read A(0), B(0) from DRAM to LDS window(0) // and advance the DRAM windows @@ -420,10 +444,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step); // tile distribution for the register tiles + // Use custom distributions that account for packed types constexpr auto ALdsTileDistr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + make_static_tile_distribution(Policy::template MakeALdsBlockDistributionEncode()); constexpr auto BLdsTileDistr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + make_static_tile_distribution(Policy::template MakeBLdsBlockDistributionEncode()); using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 638e7fdff1..b996055e99 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -281,6 +281,163 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return BlockGemmARegBRegCRegV1{}; } + // Custom warp distribution encodings that account for packed types + // For 16x16x128 MFMA with pk_fp4_t, the K dimension must use storage elements, not logical elements + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_AWarpDstrEncoding() + { + // For 16x16x128 MFMA with pk_fp4_t (PackedSize=2) + // Physical layout in registers: [16 M-lanes, 4 K-lanes, 16 bytes per lane] + // Each byte stores 2 fp4 values, so 16 bytes = 32 fp4 values + // But we need to use STORAGE size (16) not LOGICAL size (32) in the distribution + using AsDataType = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + constexpr index_t APackedSize = numeric_traits::PackedSize; + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 32 / APackedSize; // Storage elements, not logical! + + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BWarpDstrEncoding() + { + using BsDataType = remove_cvref_t; + using BDataType = remove_cvref_t{}, BsDataType>>; + constexpr index_t BPackedSize = numeric_traits::PackedSize; + + constexpr index_t kBNLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 32 / BPackedSize; // Storage elements! + + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + + // Custom LDS block distributions that account for packed types + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDistributionEncode() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t MPerXdl = WarpTile::at(number<0>{}); + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + + using AsDataType = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + constexpr index_t APackedSize = numeric_traits::PackedSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + // IMPORTANT: Use packed K for iteration count + // LDS shape is [MPerBlock, KPerBlock / APackedSize] + // WarpGemm expects [MPerXdl, KPerXdl / APackedSize] per warp per iteration + constexpr index_t KIterPerWarp = (KPerBlock / APackedSize) / (KPerXdl / APackedSize); + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); + + constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); + + if constexpr(UseDefaultScheduler) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + // Use custom warp encoding that accounts for packed types + return detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding()); + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + // Use custom warp encoding that accounts for packed types + return detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, MakeMX_AWarpDstrEncoding()); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDistributionEncode() + { + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + constexpr index_t MWarp = BlockWarps::at(number<0>{}); + constexpr index_t NWarp = BlockWarps::at(number<1>{}); + constexpr index_t NPerXdl = WarpTile::at(number<1>{}); + constexpr index_t KPerXdl = WarpTile::at(number<2>{}); + + using BsDataType = remove_cvref_t; + using BDataType = remove_cvref_t{}, BsDataType>>; + constexpr index_t BPackedSize = numeric_traits::PackedSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + // IMPORTANT: Use packed K for iteration count + // LDS shape is [NPerBlock, KPerBlock / BPackedSize] + // WarpGemm expects [NPerXdl, KPerXdl / BPackedSize] per warp per iteration + constexpr index_t KIterPerWarp = (KPerBlock / BPackedSize) / (KPerXdl / BPackedSize); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); + + constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); + + if constexpr(UseDefaultScheduler) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + // Use custom warp encoding that accounts for packed types + return detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding()); + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + // Use custom warp encoding that accounts for packed types + return detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, MakeMX_BWarpDstrEncoding()); + } + } + // MX Scale tile distributions for loading from global memory // Using the proven "Flat" patterns from v1 policy template