From d2a7c2f0417f3e525fa7ff7a7c22cd673e9e1456 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 23 Jan 2026 11:01:43 -0500 Subject: [PATCH] compiles again using get_y_sliced_thread_data in warpgemm loop --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 37 +++++++++--------- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 39 +++++++++++++------ 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 7377430a50..cd486df521 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -299,7 +299,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< [&](auto idx) { return make_tile_window( a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), a_dram_block_window_tmp[number{}].get_window_origin(), Policy::template MakeADramTileDistribution()); }, @@ -309,7 +309,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< [&](auto idx) { return make_tile_window( b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), b_dram_block_window_tmp[number{}].get_window_origin(), Policy::template MakeBDramTileDistribution()); }, @@ -364,6 +364,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< scale_b_dram_window.get_load_offset(tuple, number>{})); // this pipeline has a pair of LDS buffers per logical tile + // TODO: check for packed size - are these blocks too big? auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); @@ -372,14 +373,14 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< if constexpr(is_a_load_tr_v) return make_tuple(number{}, number{}); else - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); }(); constexpr auto b_lds_shape = []() { if constexpr(is_b_load_tr_v) return make_tuple(number{}, number{}); else - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); }(); // LDS tile windows for storing, one per LDS buffer @@ -439,6 +440,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack); + static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!"); // Load a sample scale tile to get the type auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple, number<0>>{}); @@ -520,7 +522,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< }); // Warp GEMM loop with MX scaling - auto warp_gemm_loop = [&](auto& a_block_tile, auto& b_block_tile, auto& scale_a, auto& scale_b) { + auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) { // Extract A/B values from block tiles to warp iteration structure constexpr auto a_warp_y_lengths = to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -537,25 +539,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) { constexpr auto OpSelA = kScaleInPack; + // read A warp tensor from A block tensor + typename WarpGemm::AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) { constexpr auto OpSelB = kScaleInPack; - // Extract A/B values for this iteration - create warp tensors - typename WarpGemm::AWarpTensor a_warp_tensor{}; - const auto a_thread_data = a_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, a_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) { - a_warp_tensor.get_thread_buffer()(i) = a_thread_data[i]; - }); - - typename WarpGemm::BWarpTensor b_warp_tensor{}; - const auto b_thread_data = b_block_tile.get_y_sliced_thread_data( + // read B warp tensor from B block tensor + typename WarpGemm::BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - static_for<0, b_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) { - b_warp_tensor.get_thread_buffer()(i) = b_thread_data[i]; - }); WarpGemm{}.template operator()( c_warp_tensors(m_iter)(n_iter), diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 6633e9493e..638e7fdff1 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -29,9 +29,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { + // Get packed sizes for A/B using AsDataType = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - + using ADataType = remove_cvref_t{}, AsDataType>>; // Force 16-byte vector loads for optimal async buffer performance // For fp4 (1 byte): 16 elements = 16 bytes // For fp8 (1 byte): 16 elements = 16 bytes @@ -53,9 +53,9 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { + // Get packed sizes for A/B using BsDataType = remove_cvref_t; - using BDataType = remove_cvref_t{}, BsDataType>>; - + using BDataType = remove_cvref_t{}, BsDataType>>; // Force 16-byte vector loads for optimal async buffer performance // For fp4 (1 byte): 16 elements = 16 bytes // For fp8 (1 byte): 16 elements = 16 bytes @@ -86,13 +86,17 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy using ALayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>; + // Get packed sizes for A/B + using AsDataType = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + constexpr index_t APackedSize = numeric_traits>::PackedSize; if constexpr(std::is_same_v) { using TileEncodingPattern = tile_distribution_encoding_pattern_2d; @@ -123,6 +127,11 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy using BLayout = remove_cvref_t< std::tuple_element_t{}, remove_cvref_t>>; + + // Get packed sizes for A/B + using BsDataType = remove_cvref_t; + using BDataType = remove_cvref_t{}, BsDataType>>; + constexpr index_t BPackedSize = numeric_traits>::PackedSize; if constexpr(std::is_same_v) { @@ -141,7 +150,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy using TileEncodingPattern = tile_distribution_encoding_pattern_2d; @@ -153,8 +162,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy typename OverrideADataType = remove_cvref_t> CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { + // Get packed sizes for A/B + using AsDataType = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + constexpr index_t APackedSize = numeric_traits>::PackedSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; if constexpr(is_a_load_tr) { // TODO: better LDS descriptor for performance @@ -191,8 +205,13 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { + // Get packed sizes for A/B + using BsDataType = remove_cvref_t; + using BDataType = remove_cvref_t{}, BsDataType>>; + constexpr index_t BPackedSize = numeric_traits>::PackedSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; if constexpr(is_b_load_tr) { // TODO: better LDS descriptor for performance @@ -300,10 +319,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy constexpr index_t NWarp = BlockWarps::at(number<1>{}); constexpr index_t NPerXdl = WarpTile::at(number<1>{}); constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma - - static_assert(K_Lane == 4, "K_Lane must be 4 for 16x16 mfma"); - static_assert(NPerXdl == 16, "NPerXdl must be 16 for 16x16 mfma"); - static_assert(MWarp == 1, "MWarp must be 1 for 16x16 mfma"); // Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile // Layout is [K, N] where K is packed int32