diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index defeffc2ee..1906b0bda7 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index beb6987605..3ec90e7f00 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -166,7 +166,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: // clear c mem if(args.k_batch > 1) hipGetErrorString(hipMemsetAsync( - args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; ave_time = ck_tile::launch_kernel_preprocess( s, diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c564d7d1b1..8782d2bb6a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem, AccDataType, CDataType, + ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a9ed1519e6..d2e1bde58f 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -447,6 +447,7 @@ struct FlatmmKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr); @@ -454,7 +455,7 @@ struct FlatmmKernel auto& c_block_window = gemm_tile_windows.at(I2); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr); + c_block_window, c_block_tile, d_block_window, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 8db822ebd1..a1d37f0824 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,8 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 9ef7f3f0ef..55220730cd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -1,5 +1,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 217408fffa..881467cb94 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t kLdsAlignmentInBytes = 16; [[nodiscard]] CK_TILE_HOST static const std::string GetName() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 678fb6eb46..b349991470 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -32,6 +32,8 @@ struct GemmPipelineProblemBase static constexpr bool TransposeC = Traits::TransposeC; + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 353192d86f..c6f83068a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -28,6 +28,7 @@ struct TileGemmTraits static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; + static constexpr index_t NumWaveGroups = 1; }; template