Bench, add tuple param

This commit is contained in:
Mateusz Ozga
2025-07-02 09:49:41 +00:00
parent 70e5534864
commit 138b576531
4 changed files with 26 additions and 22 deletions

View File

@@ -27,8 +27,8 @@ struct Default2DEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
};
template <typename ADataType_,
typename BDataType_,
template <typename AsDataType_,
typename BsDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
@@ -47,8 +47,8 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
UseRawStore_,
MemoryOperation_>
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
@@ -115,13 +115,17 @@ template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;

View File

@@ -29,8 +29,8 @@ LAYOUT_MAP = {
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
ck_tile::DefaultGemm2DEpilogueProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
AccDataType,
CDataType,
CLayout,
@@ -46,8 +46,8 @@ DEFAULT_EPILOGUE = """
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
ck_tile::tuple<>,
AccDataType,
CDataType,

View File

@@ -257,14 +257,14 @@ struct GemmKernel {{
TileParitionerM01>;
using Traits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ck_tile::tuple<ALayout>, ck_tile::tuple<BLayout>, CLayout>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
ck_tile::tuple<ALayout>, ck_tile::tuple<BLayout>, CLayout, TransposeC, structured_sparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
ck_tile::GemmPipelineProblem<ck_tile::tuple<ADataType>, ck_tile::tuple<BDataType>, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
@@ -283,8 +283,8 @@ struct GemmKernel {{
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
ck_tile::UniversalGemmPipelineProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
AccDataType,
GemmShape,
GemmUniversalTraits,
@@ -327,15 +327,15 @@ struct GemmKernel {{
}};
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{{}})));
args.M, args.K, args.stride_As[0], is_row_major(ALayout{{}})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{{}})));
args.K, args.N, args.stride_Bs[0], is_row_major(BLayout{{}})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {{

View File

@@ -90,16 +90,16 @@ class GemmProfiler
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs<> gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{a_m_k_dev_buf.GetDeviceBuffer()},
{b_k_n_dev_buf.GetDeviceBuffer()},
{}, // ds_ptr
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
{gemm_problem.stride_a_},
{gemm_problem.stride_b_},
{}, // stride_Ds
gemm_problem.stride_c_,
};