From 0f4d68633be8617cf416f080661b6cf6ce5f9380 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 16 Jun 2025 07:54:55 -0700 Subject: [PATCH] Revert "fix the flatmm (#2349)" (#2352) This reverts commit fc651956056798bf530a06cd322a8f6893d533ab. [ROCm/composable_kernel commit: 5523df4b2dfab16d6144d7717b3b075f8c6d5104] --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 3 --- include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp | 3 +-- include/ck_tile/ops/gemm.hpp | 2 +- script/run_ck_profiler_gemm_with_csv_shapes.py | 4 ++-- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 8782d2bb6a..c564d7d1b1 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,12 +49,9 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, - ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index d2e1bde58f..a9ed1519e6 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,7 +447,6 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -455,7 +454,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr); + c_block_window, c_block_tile, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index a1d37f0824..8db822ebd1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index 54b4b337de..1f7ec7585f 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -278,13 +278,13 @@ def main(): shapes = tuples(filename) all_results = [] + from tqdm import tqdm from functools import partial from os import path profiler_bin = path.join(args["build_dir"], "bin", "ckProfiler") - total = len(shapes) - for idx, s in enumerate(shapes, 1): + for s in tqdm(shapes): run_shape_stdout_lines = run_shape( s, profiler_bin, args["op_name"], args["dtype"], args["layout"] )