From 138b576531f55203d6c7908ea299f4176cdff480 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Wed, 2 Jul 2025 09:49:41 +0000 Subject: [PATCH] Bench, add tuple param --- .../ck_tile/ops/epilogue/default_2d_epilogue.hpp | 16 ++++++++++------ tile_engine/ops/gemm/codegen_utils.py | 8 ++++---- tile_engine/ops/gemm/gemm_instance_builder.py | 16 ++++++++-------- tile_engine/ops/gemm/gemm_profiler.hpp | 8 ++++---- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index ff41ac0d61..034ef72ef7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -27,8 +27,8 @@ struct Default2DEpilogueProblem static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; -template { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; @@ -115,13 +115,17 @@ template struct DefaultGemm2DEpilogue : public Default2DEpilogue { using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using ADataType = remove_cvref_t{}, AsDataType>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = ck_tile::tuple<>; using DsLayout = ck_tile::tuple<>; using CLayout = remove_cvref_t; diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 9ff76724cc..d9955a3294 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -29,8 +29,8 @@ LAYOUT_MAP = { DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem, + ck_tile::tuple, AccDataType, CDataType, CLayout, @@ -46,8 +46,8 @@ DEFAULT_EPILOGUE = """ CSHUFFLE_EPILOGUE = """ using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, + ck_tile::tuple, ck_tile::tuple<>, AccDataType, CDataType, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index de1fd0bb62..d895954182 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -257,14 +257,14 @@ struct GemmKernel {{ TileParitionerM01>; using Traits = - ck_tile::TileGemmTraits; + ck_tile::TileGemmTraits, ck_tile::tuple, CLayout>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + ck_tile::tuple, ck_tile::tuple, CLayout, TransposeC, structured_sparsity>; using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + ck_tile::GemmPipelineProblem, ck_tile::tuple, AccDataType, GemmShape, Traits>; using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; @@ -283,8 +283,8 @@ struct GemmKernel {{ constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem, + ck_tile::tuple, AccDataType, GemmShape, GemmUniversalTraits, @@ -327,15 +327,15 @@ struct GemmKernel {{ }}; ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{{}}))); + args.M, args.K, args.stride_As[0], is_row_major(ALayout{{}}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{{}}))); + args.K, args.N, args.stride_Bs[0], is_row_major(BLayout{{}}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer); + kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer); rotating_mem.Print(); auto run_flush_cache = [&]() {{ diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 2b0cbe7880..2f63061976 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -90,16 +90,16 @@ class GemmProfiler c_m_n_dev_result.SetZero(); ck_tile::GemmHostArgs<> gemm_args = { - a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), + {a_m_k_dev_buf.GetDeviceBuffer()}, + {b_k_n_dev_buf.GetDeviceBuffer()}, {}, // ds_ptr c_m_n_dev_buf.GetDeviceBuffer(), gemm_problem.split_k_, gemm_problem.m_, gemm_problem.n_, gemm_problem.k_, - gemm_problem.stride_a_, - gemm_problem.stride_b_, + {gemm_problem.stride_a_}, + {gemm_problem.stride_b_}, {}, // stride_Ds gemm_problem.stride_c_, };