From a8d48f92247cb0ba64a15614a2a2223d0d5434da Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Thu, 5 Feb 2026 17:31:32 +0000 Subject: [PATCH] now offsetting with M/MPerXdl to get scales --- .../ops/gemm_mx/kernel/gemm_mx_kernel.hpp | 1 + .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 79 +++++++++++-------- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 12 +-- 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index 0776537c34..4a49bbe658 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -386,6 +386,7 @@ struct MXGemmKernel : UniversalGemmKernel{}, number{}), - scale_a_window.get_window_origin(), + // Scale tensor views and base origins for creating tile windows per iteration + const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view(); + const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view(); + auto scale_a_base_origin = scale_a_window.get_window_origin(); + auto scale_b_base_origin = scale_b_window.get_window_origin(); + + // Create sample scale windows to determine tile types + auto scale_a_dram_window_sample = make_tile_window( + scale_a_tensor_view, + make_tuple(number{}, number{}), + scale_a_base_origin, Policy::template MakeMX_ScaleA_DramTileDistribution()); - const auto scale_a_dram_step_m = amd_wave_read_first_lane( - scale_a_dram_window.get_load_offset(tuple, number<0>>{})); - - // Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl] - // With strided packing: KXdlPack kIters share each int32 via OpSel - auto scale_b_dram_window = make_tile_window( - scale_b_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - scale_b_window.get_window_origin(), + auto scale_b_dram_window_sample = make_tile_window( + scale_b_tensor_view, + make_tuple(number{}, number{}), + scale_b_base_origin, Policy::template MakeMX_ScaleB_DramTileDistribution()); - - const auto scale_b_dram_step_n = amd_wave_read_first_lane( - scale_b_dram_window.get_load_offset(tuple, number<0>>{})); // this pipeline has a pair of LDS buffers per logical tile // TODO: check for packed size - are these blocks too big? @@ -561,8 +558,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!"); // Load a sample scale tile to get the type after distribution - auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple, number<0>>{}); - auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple, number<0>>{}); + auto scale_a_sample = load_tile(scale_a_dram_window_sample); + auto scale_b_sample = load_tile(scale_b_dram_window_sample); using ScaleTileElementA = remove_cvref_t; using ScaleTileElementB = remove_cvref_t; @@ -578,22 +575,40 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< // Helper function to load scales auto load_scales_ = [&](auto& scale_a, auto& scale_b) { // Load scales for each M/N iteration + // Create tile windows from scratch with correct origins for each iteration static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) { - // scale_a(mIter)(kPacked) = load_tile_with_offset( - // scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k); - // }); - scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{})); + // Scale A: create window at origin {base_m + mIter * MPerXdl, base_k} + auto scale_a_origin = scale_a_base_origin; + scale_a_origin[number<0>{}] += mIter * MPerXdl; + + auto scale_a_tile_window = make_tile_window( + scale_a_tensor_view, + make_tuple(number{}, number{}), + scale_a_origin, + Policy::template MakeMX_ScaleA_DramTileDistribution()); + + scale_a(mIter) = load_tile(scale_a_tile_window); }); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // Scale B viewed as [N, K], so N is first dimension - scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{})); + // Scale B: layout is [N, K], create window at origin {base_n + nIter * NPerXdl, base_k} + auto scale_b_origin = scale_b_base_origin; + scale_b_origin[number<0>{}] += nIter * NPerXdl; + + auto scale_b_tile_window = make_tile_window( + scale_b_tensor_view, + make_tuple(number{}, number{}), + scale_b_origin, + Policy::template MakeMX_ScaleB_DramTileDistribution()); + + scale_b(nIter) = load_tile(scale_b_tile_window); }); - // Advance to next KPerBlock - // Scale A: [M, K] -> advance in K (second dimension) - // Scale B: viewed as [N, K] -> advance in K (second dimension) - move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack}); - move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack}); + + // Advance base origins to next KPerBlock + // Scale A: [M, K] -> advance in K (second dimension, index 1) + // Scale B: [N, K] -> advance in K (second dimension, index 1) + scale_a_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack; + scale_b_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack; }; // constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() { diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 2d3c841483..e50dd388c7 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -221,11 +221,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, // repeat over NWarps tuple, // M dimension - sequence>, // K dimension + sequence>, // K dimension tuple, sequence<2, 1>>, // , tuple, sequence<1, 1>>, - sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock - sequence<0>>{}); + sequence<2, 2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock + sequence<0, 2>>{}); } template @@ -251,11 +251,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, // repeat over MWarps tuple, // N dimension (first) - sequence>, // K dimension (second) + sequence>, // K dimension (second) tuple, sequence<2, 1>>, // which direction tuple, sequence<1, 1>>, // which index - sequence<2>, // replicate N - sequence<0>>{}); + sequence<2, 2>, // replicate N + sequence<0, 2>>{}); } }; } // namespace ck_tile