diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c564d7d1b1..8782d2bb6a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a9ed1519e6..d2e1bde58f 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,6 +447,7 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -454,7 +455,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr); + c_block_window, c_block_tile, d_block_window, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 8db822ebd1..a1d37f0824 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/script/run_ck_profiler_gemm_with_csv_shapes.py b/script/run_ck_profiler_gemm_with_csv_shapes.py index 1f7ec7585f..54b4b337de 100644 --- a/script/run_ck_profiler_gemm_with_csv_shapes.py +++ b/script/run_ck_profiler_gemm_with_csv_shapes.py @@ -278,13 +278,13 @@ def main(): shapes = tuples(filename) all_results = [] - from tqdm import tqdm from functools import partial from os import path profiler_bin = path.join(args["build_dir"], "bin", "ckProfiler") - for s in tqdm(shapes): + total = len(shapes) + for idx, s in enumerate(shapes, 1): run_shape_stdout_lines = run_shape( s, profiler_bin, args["op_name"], args["dtype"], args["layout"] )