From b124a72ff5953c2a2a4727548f6b3843d06616f3 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 30 Jan 2026 12:40:48 -0500 Subject: [PATCH] revert mostly back to original comp_async --- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 366 ++++++++++++------ ...ine_ag_bg_cr_comp_async_default_policy.hpp | 305 +++++++++++---- 2 files changed, 475 insertions(+), 196 deletions(-) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 30ae9d9058..9af8654e5b 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -298,38 +298,96 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< auto a_tile_windows = generate_tuple( [&](auto idx) { + /// NOTE: flatmm style byte tensor approach: // Create tile window with STORAGE dimensions to match LDS + // auto&& tensor_view_tmp = a_dram_block_window_tmp[number{}].get_bottom_tensor_view(); + // auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + // const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + // auto&& a_tensor_view = make_naive_tensor_view( + // static_cast(byte_ptr), + // make_tuple(rows, cols / APackedSize), + // make_tuple(cols / APackedSize, 1), + // number<16>{}, + // number<1>{}); + // return make_tile_window(a_tensor_view, + // make_tuple(number{}, number{}), + // [&]() { + // auto origin = a_dram_block_window_tmp[number{}].get_window_origin(); + // if constexpr(is_a_col_major) { + // origin[0] = origin[0] / APackedSize; // Adjust K origin + // } else { + // origin[1] = origin[1] / APackedSize; // Adjust K origin + // } + // return origin; + // }(), + // Policy::template MakeADramTileDistribution()); + /// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize + // return make_tile_window( + // a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // [&]() { + // auto origin = a_dram_block_window_tmp[number{}].get_window_origin(); + // if constexpr(is_a_col_major) { + // origin[0] = origin[0] / APackedSize; // Adjust K origin + // } else { + // origin[1] = origin[1] / APackedSize; // Adjust K origin + // } + // return origin; + // }(), + // Policy::template MakeADramTileDistribution()); + /// NOTE: use original shapes return make_tile_window( - a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - [&]() { - auto origin = a_dram_block_window_tmp[number{}].get_window_origin(); - if constexpr(is_a_col_major) { - origin[0] = origin[0] / APackedSize; // Adjust K origin - } else { - origin[1] = origin[1] / APackedSize; // Adjust K origin - } - return origin; - }(), + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), Policy::template MakeADramTileDistribution()); }, number{}); // B DRAM window(s) for load auto b_tile_windows = generate_tuple( [&](auto idx) { + /// NOTE: flatmm style byte tensor approach: // Create tile window with STORAGE dimensions to match LDS + // auto&& tensor_view_tmp = b_dram_block_window_tmp[number{}].get_bottom_tensor_view(); + // auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + // const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + // auto&& b_tensor_view = make_naive_tensor_view( + // static_cast(byte_ptr), + // make_tuple(rows, cols / BPackedSize), + // make_tuple(cols / BPackedSize, 1), + // number<16>{}, + // number<1>{}); + // return make_tile_window(b_tensor_view, + // make_tuple(number{}, number{}), + // [&]() { + // auto origin = b_dram_block_window_tmp[number{}].get_window_origin(); + // if constexpr(is_b_row_major) { + // origin[0] = origin[0] / BPackedSize; // Adjust K origin + // } else { + // origin[1] = origin[1] / BPackedSize; // Adjust K origin + // } + // return origin; + // }(), + // Policy::template MakeBDramTileDistribution()); + /// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize + // return make_tile_window( + // b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // [&]() { + // auto origin = b_dram_block_window_tmp[number{}].get_window_origin(); + // if constexpr(is_b_row_major) { + // origin[0] = origin[0] / BPackedSize; // Adjust K origin + // } else { + // origin[1] = origin[1] / BPackedSize; // Adjust K origin + // } + // return origin; + // }(), + // Policy::template MakeBDramTileDistribution()); + /// NOTE: use original shapes return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - [&]() { - auto origin = b_dram_block_window_tmp[number{}].get_window_origin(); - if constexpr(is_b_row_major) { - origin[0] = origin[0] / BPackedSize; // Adjust K origin - } else { - origin[1] = origin[1] / BPackedSize; // Adjust K origin - } - return origin; - }(), + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), Policy::template MakeBDramTileDistribution()); }, number{}); @@ -382,22 +440,41 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< // this pipeline has a pair of LDS buffers per logical tile // TODO: check for packed size - are these blocks too big? + /// NOTE: flatmm style byte tensor approach: + // auto&& [a_lds_block0, b_lds_block0] = Base::template GetABLdsTensorViews(p_smem_0); + // auto&& [a_lds_block1, b_lds_block1] = Base::template GetABLdsTensorViews(p_smem_1); + /// NOTE: with original fp4 types: auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); // set up LDS tile shapes - always use STORAGE dimensions for K + /// NOTE: flatmm style byte tensor approach: + // constexpr auto a_lds_shape = []() { + // if constexpr(is_a_load_tr_v) + // return make_tuple(number{}, number{}); + // else + // return make_tuple(number{}, number{}); + // }(); + + // constexpr auto b_lds_shape = []() { + // if constexpr(is_b_load_tr_v) + // return make_tuple(number{}, number{}); + // else + // return make_tuple(number{}, number{}); + // }(); + /// NOTE: use original shapes constexpr auto a_lds_shape = []() { if constexpr(is_a_load_tr_v) - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); else - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); }(); constexpr auto b_lds_shape = []() { if constexpr(is_b_load_tr_v) - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); else - return make_tuple(number{}, number{}); + return make_tuple(number{}, number{}); }(); // LDS tile windows for storing, one per LDS buffer @@ -413,10 +490,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + /// NOTE: flatmm style way to calculate steps with packed size + // constexpr ADramTileWindowStep a_dram_tile_window_step = + // is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize); + // constexpr BDramTileWindowStep b_dram_tile_window_step = + // is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize); + /// NOTE: use original steps and assume that PackedSize is correctly applied elsewhere constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize); + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize); + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); // read A(0), B(0) from DRAM to LDS window(0) // and advance the DRAM windows @@ -426,8 +509,13 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); // Initialize WarpGemm for MX scaling - using WarpGemm = typename remove_cvref_t())>::WarpGemm; - using CWarpTensor = typename WarpGemm::CWarpTensor; + // using WarpGemm = typename remove_cvref_t())>::WarpGemm; + // using CWarpTensor = typename WarpGemm::CWarpTensor; + + // Initialize block gemm and C block tile + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + clear_tile(c_block_tile); // read A(1), B(1) from DRAM to LDS window(1) // and advance the DRAM windows @@ -449,6 +537,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< ALdsTile a_block_tile0, a_block_tile1; BLdsTile b_block_tile0, b_block_tile1; + // Some sanity checks on the LDS tile sizes static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!"); static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!"); static_assert(Policy::template GetSmemSizeA() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!"); @@ -496,36 +585,44 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0}); }; - constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() { - if constexpr(is_a_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename decltype(ALdsTileDistr)::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); - else - return ALdsTileDistr; - }(); - constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() { - if constexpr(is_b_load_tr_v) - return make_static_tile_distribution( - typename InputTileDistributionTraits< - typename decltype(BLdsTileDistr)::DstrEncode, - typename Problem::BDataType>::TransposedDstrEncode{}); - else - return BLdsTileDistr; - }(); + // constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() { + // if constexpr(is_a_load_tr_v) + // return make_static_tile_distribution( + // typename InputTileDistributionTraits< + // typename decltype(ALdsTileDistr)::DstrEncode, + // typename Problem::ADataType>::TransposedDstrEncode{}); + // else + // return ALdsTileDistr; + // }(); + // constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() { + // if constexpr(is_b_load_tr_v) + // return make_static_tile_distribution( + // typename InputTileDistributionTraits< + // typename decltype(BLdsTileDistr)::DstrEncode, + // typename Problem::BDataType>::TransposedDstrEncode{}); + // else + // return BLdsTileDistr; + // }(); // LDS tile windows for reading; // they share the data pointer with the LDS windows for storing // but also associate with a distribution to produce a register tile when reading auto a_lds_ld_window0 = - make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr); + make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, ALdsTileDistr); auto a_lds_ld_window1 = - make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr); + make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, ALdsTileDistr); auto b_lds_ld_window0 = - make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr); + make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, BLdsTileDistr); auto b_lds_ld_window1 = - make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr); + make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, BLdsTileDistr); + // auto a_lds_ld_window0 = + // make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution()); + // auto a_lds_ld_window1 = + // make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution()); + // auto b_lds_ld_window0 = + // make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution()); + // auto b_lds_ld_window1 = + // make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution()); static_assert(!(is_tile_window_linear_v) && !(is_tile_window_linear_v) && @@ -534,61 +631,62 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< "LDS windows must not be linear"); // Create warp-level C tensors (one per M/N iteration) - statically_indexed_array, MIterPerWarp> c_warp_tensors; + // statically_indexed_array, MIterPerWarp> c_warp_tensors; // Initialize C tensors - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - clear_tile(c_warp_tensors(mIter)(nIter)); - }); - }); + /// TODO: create CBlockTile with block_gemm.MakeCBlockTile() + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // clear_tile(c_warp_tensors(mIter)(nIter)); + // }); + // }); // Warp GEMM loop with MX scaling - auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) { - // Extract A/B values from block tiles to warp iteration structure - constexpr auto a_warp_y_lengths = - to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto b_warp_y_lengths = - to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + // auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) { + // // Extract A/B values from block tiles to warp iteration structure + // constexpr auto a_warp_y_lengths = + // to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + // constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + // constexpr auto b_warp_y_lengths = + // to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + // constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) { - // Map k_iter to packed scale index and OpSel - constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack); - // constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack; - constexpr index_t kScaleInPack = k_iter; + // static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) { + // // Map k_iter to packed scale index and OpSel + // constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack); + // // constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack; + // constexpr index_t kScaleInPack = k_iter; - static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) { - constexpr auto OpSelA = kScaleInPack; + // static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) { + // constexpr auto OpSelA = kScaleInPack; - // read A warp tensor from A block tensor - typename WarpGemm::AWarpTensor a_warp_tensor; + // // read A warp tensor from A block tensor + // typename WarpGemm::AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + // a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data( + // merge_sequences(sequence{}, a_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) { - constexpr auto OpSelB = kScaleInPack; + // static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) { + // constexpr auto OpSelB = kScaleInPack; - // read B warp tensor from B block tensor - typename WarpGemm::BWarpTensor b_warp_tensor; + // // read B warp tensor from B block tensor + // typename WarpGemm::BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data( + // merge_sequences(sequence{}, b_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - WarpGemm{}.template operator()( - c_warp_tensors(m_iter)(n_iter), - a_warp_tensor, - b_warp_tensor, - scale_a(m_iter)(number{}).get_thread_buffer()[0], - scale_b(n_iter)(number{}).get_thread_buffer()[0]); - }); - }); - }); - }; + // WarpGemm{}.template operator()( + // c_warp_tensors(m_iter)(n_iter), + // a_warp_tensor, + // b_warp_tensor, + // scale_a(m_iter)(number{}).get_thread_buffer()[0], + // scale_b(n_iter)(number{}).get_thread_buffer()[0]); + // }); + // }); + // }); + // }; // write to LDS window(0) must complete before the local prefetch block_sync_lds_direct_load(); @@ -636,12 +734,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-3) = A(i-3) @ B(i-3) with MX scaling - warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_ping; + ignore = scale_b_tile_ping; + HotLoopScheduler(); // Load scales for iteration i+2 (ping) if (i_global_read + 2 < num_loop) { load_scales_(scale_a_tile_ping, scale_b_tile_ping); } - HotLoopScheduler(); } // pong { @@ -661,13 +763,17 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-2) = A(i-2) @ B(i-2) with MX scaling - warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_pong; + ignore = scale_b_tile_pong; + HotLoopScheduler(); // Load scales for iteration i+2 (pong) /// TODO: check condition if (i_global_read + 2 < num_loop) { load_scales_(scale_a_tile_pong, scale_b_tile_pong); } - HotLoopScheduler(); } i_global_read += 2; } while(i_global_read < num_loop); @@ -681,7 +787,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling - warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_ping; + ignore = scale_b_tile_ping; /// TODO: load next scales to ping for the last iteration } { @@ -691,11 +801,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_pong; + ignore = scale_b_tile_pong; } { // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_ping; + ignore = scale_b_tile_ping; } } else if(TailNum == TailNumber::Two) @@ -706,36 +824,48 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_ping; + ignore = scale_b_tile_ping; } { // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_pong; + ignore = scale_b_tile_pong; } } else if(TailNum == TailNumber::One) { block_sync_lds(); // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + /// TODO: remove these after creating a block gemm with scales + ignore = scale_a_tile_ping; + ignore = scale_b_tile_ping; __builtin_amdgcn_sched_barrier(0); } // Convert warp-level C tensors to block tile format - auto c_block_tile = BlockGemm{}.MakeCBlockTile(); - using CWarpDstr = typename WarpGemm::CWarpDstr; - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + // auto c_block_tile = BlockGemm{}.MakeCBlockTile(); + // using CWarpDstr = typename WarpGemm::CWarpDstr; + // constexpr auto c_warp_y_lengths = + // to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensors(mIter)(nIter).get_thread_buffer()); - }); - }); + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // c_block_tile.set_y_sliced_thread_data( + // merge_sequences(sequence{}, c_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + // c_warp_tensors(mIter)(nIter).get_thread_buffer()); + // }); + // }); return c_block_tile; } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 146d42abb2..55c7efb10a 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -4,9 +4,11 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include namespace ck_tile { // Default policy for MXGemmPipelineAgBgCrCompAsync @@ -70,91 +72,234 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return vector_size; } - // DRAM tile distributions use STORAGE dimensions (for the storage tensor view) - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - using AsDataType = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - constexpr index_t APackedSize = numeric_traits>::PackedSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions - constexpr index_t VecLoadSize = GetVectorSizeA(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // // DRAM tile distributions use STORAGE dimensions (for the storage tensor view) + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + // { + // // using AsDataType = remove_cvref_t; + // // using ADataType = remove_cvref_t{}, AsDataType>>; - using ALayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + // // constexpr index_t BlockSize = Problem::kBlockSize; + // // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // // constexpr index_t APackedSize = numeric_traits>::PackedSize; + + // // constexpr index_t K2 = 16; // 16 bytes + // // constexpr index_t K1 = 128 / K2; // 8 + // // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize + + // // constexpr index_t M2 = get_warp_size() / K1; // 8 + // // constexpr index_t M1 = BlockSize / get_warp_size(); // 4 + // // constexpr index_t M0 = MPerBlock / (M2 * M1); + + // // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); + // // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, + // // "K0, K1, K2 must cover whole KPerBlock!"); + + // // return make_static_tile_distribution( + // // tile_distribution_encoding< // + // // sequence<1>, + // // tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 + // // tuple, sequence<1, 2>>, // M1 M2,K1 + // // tuple, sequence<2, 1>>, + // // sequence<1, 2, 2>, // M0,K0,K2 + // // sequence<0, 0, 2>>{}); + // constexpr index_t BlockSize = Problem::kBlockSize; + // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + // /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions + // // using AsDataType = remove_cvref_t; + // // using ADataType = remove_cvref_t{}, AsDataType>>; + // // constexpr index_t APackedSize = numeric_traits>::PackedSize; + // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions + // /// NOTE: use original KPerBlock + // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // constexpr index_t VecLoadSize = GetVectorSizeA(); + // constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // using ALayout = remove_cvref_t< + // std::tuple_element_t{}, remove_cvref_t>>; - if constexpr(std::is_same_v) - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - else - { - static_assert(false, "Not implemented"); - // using TileEncodingPattern = - // tile_distribution_encoding_pattern_2d; - // return TileEncodingPattern::make_2d_static_tile_distribution(); - } - } + // if constexpr(std::is_same_v) + // { + // using TileEncodingPattern = + // tile_distribution_encoding_pattern_2d; + // return TileEncodingPattern::make_2d_static_tile_distribution(); + // } + // else + // { + // static_assert(false, "Not implemented"); + // // using TileEncodingPattern = + // // tile_distribution_encoding_pattern_2d; + // // return TileEncodingPattern::make_2d_static_tile_distribution(); + // } + // } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - using BsDataType = remove_cvref_t; - using BDataType = remove_cvref_t{}, BsDataType>>; - constexpr index_t BPackedSize = numeric_traits>::PackedSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions - constexpr index_t VecLoadSize = GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + // { + // /// NOTE: flatmm style dstr + // // using BsDataType = remove_cvref_t; + // // using BDataType = remove_cvref_t{}, BsDataType>>; + + // // constexpr index_t BlockSize = Problem::kBlockSize; + // // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // // constexpr index_t BPackedSize = numeric_traits>::PackedSize; - using BLayout = remove_cvref_t< - std::tuple_element_t{}, remove_cvref_t>>; + // // constexpr index_t K2 = 16; // 16 bytes + // // constexpr index_t K1 = 128 / K2; // 8 + // // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize + + // // constexpr index_t N2 = get_warp_size() / K1; // 8 + // // constexpr index_t N1 = BlockSize / get_warp_size(); // 4 + // // constexpr index_t N0 = NPerBlock / (N2 * N1); + + // // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!"); + // // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock, + // // "K0, K1, K2 must cover whole KPerBlock!"); + + // // return make_static_tile_distribution( + // // tile_distribution_encoding< // + // // sequence<1>, + // // tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 + // // tuple, sequence<1, 2>>, // M1 M2,K1 + // // tuple, sequence<2, 1>>, + // // sequence<1, 2, 2>, // N0,K0,K2 + // // sequence<0, 0, 2>>{}); + // constexpr index_t BlockSize = Problem::kBlockSize; + // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + // /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions + // // using BsDataType = remove_cvref_t; + // // using BDataType = remove_cvref_t{}, BsDataType>>; + // // constexpr index_t BPackedSize = numeric_traits>::PackedSize; + // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions + // /// NOTE: use original KPerBlock + // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // constexpr index_t VecLoadSize = GetVectorSizeB(); + // constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // using BLayout = remove_cvref_t< + // std::tuple_element_t{}, remove_cvref_t>>; - if constexpr(std::is_same_v) - { - static_assert(false, "Not implemented"); - } - else - { - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - return TileEncodingPattern::make_2d_static_tile_distribution(); - } - } + // if constexpr(std::is_same_v) + // { + // static_assert(false, "Not implemented"); + // } + // else + // { + // using TileEncodingPattern = + // tile_distribution_encoding_pattern_2d; + // return TileEncodingPattern::make_2d_static_tile_distribution(); + // } + // } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution() + // { + // // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + // using AsDataType = remove_cvref_t; + // using ADataType = remove_cvref_t{}, AsDataType>>; + // constexpr index_t APackedSize = numeric_traits>::PackedSize; + // using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + // constexpr index_t MWarps = BlockWarps::at(number<0>{}); + // constexpr index_t NWarps = BlockWarps::at(number<1>{}); + // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + // // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + // constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2); + // constexpr index_t K_Lane = get_warp_size() / 16; // 4 + // constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 + // constexpr index_t DWORDx4 = 16; + // constexpr index_t AK1 = DWORDx4 * APackedSize; + + // if constexpr(K_Thread == AK1) + // return make_static_tile_distribution( + // tile_distribution_encoding< // + // sequence, + // tuple, sequence>, + // tuple, sequence<2, 1>>, + // tuple, sequence<0, 2>>, + // sequence<2>, + // sequence<1>>{}); + // else + // return make_static_tile_distribution( + // tile_distribution_encoding< // + // sequence, + // tuple, + // sequence>, + // tuple, sequence<2, 1>>, + // tuple, sequence<1, 2>>, + // sequence<2, 2>, + // sequence<0, 2>>{}); + // } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution() + // { + // // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + // using BsDataType = remove_cvref_t; + // using BDataType = remove_cvref_t{}, BsDataType>>; + // constexpr index_t BPackedSize = numeric_traits>::PackedSize; + // using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + // constexpr index_t MWarps = BlockWarps::at(number<0>{}); + // constexpr index_t NWarps = BlockWarps::at(number<1>{}); + // // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + // constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2); + // constexpr index_t K_Lane = get_warp_size() / 16; // 4 + // constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 + // constexpr index_t DWORDx4 = 16; + // constexpr index_t BK1 = DWORDx4 * BPackedSize; + + // if constexpr(K_Thread == BK1) + // return make_static_tile_distribution( + // tile_distribution_encoding< // + // sequence, + // tuple, sequence>, + // tuple, sequence<2, 1>>, + // tuple, sequence<0, 2>>, + // sequence<2>, + // sequence<1>>{}); + // else + // return make_static_tile_distribution( + // tile_distribution_encoding< // + // sequence, + // tuple, + // sequence>, + // tuple, sequence<2, 1>>, + // tuple, sequence<1, 2>>, + // sequence<2, 2>, + // sequence<0, 2>>{}); + // } template > CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using AsDataType = remove_cvref_t; - using ADataType = remove_cvref_t{}, AsDataType>>; - constexpr index_t APackedSize = numeric_traits>::PackedSize; - + { constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions + /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions + // using AsDataType = remove_cvref_t; + // using ADataType = remove_cvref_t{}, AsDataType>>; + // constexpr index_t APackedSize = numeric_traits>::PackedSize; + // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions + /// NOTE: use original KPerBlock + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; if constexpr(is_a_load_tr) { // TODO: better LDS descriptor for performance @@ -170,6 +315,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy else { constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack >= 16, "KPack must be at least 16"); constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -190,12 +336,14 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BsDataType = remove_cvref_t; - using BDataType = remove_cvref_t{}, BsDataType>>; - constexpr index_t BPackedSize = numeric_traits>::PackedSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; + /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions + // using BsDataType = remove_cvref_t; + // using BDataType = remove_cvref_t{}, BsDataType>>; + // constexpr index_t BPackedSize = numeric_traits>::PackedSize; + // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions + /// NOTE: use original KPerBlock + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; if constexpr(is_b_load_tr) { // TODO: better LDS descriptor for performance @@ -211,6 +359,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy else { constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack >= 16, "KPack must be at least 16"); constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}),