mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Bench, add tuple param
This commit is contained in:
@@ -27,8 +27,8 @@ struct Default2DEpilogueProblem
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
@@ -47,8 +47,8 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
@@ -115,13 +115,17 @@ template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
@@ -29,8 +29,8 @@ LAYOUT_MAP = {
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ck_tile::tuple<ADataType>,
|
||||
ck_tile::tuple<BDataType>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
@@ -46,8 +46,8 @@ DEFAULT_EPILOGUE = """
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ck_tile::tuple<ADataType>,
|
||||
ck_tile::tuple<BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
|
||||
@@ -257,14 +257,14 @@ struct GemmKernel {{
|
||||
TileParitionerM01>;
|
||||
|
||||
using Traits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ck_tile::tuple<ALayout>, ck_tile::tuple<BLayout>, CLayout>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
|
||||
ck_tile::tuple<ALayout>, ck_tile::tuple<BLayout>, CLayout, TransposeC, structured_sparsity>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
ck_tile::GemmPipelineProblem<ck_tile::tuple<ADataType>, ck_tile::tuple<BDataType>, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
|
||||
|
||||
@@ -283,8 +283,8 @@ struct GemmKernel {{
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::UniversalGemmPipelineProblem<ck_tile::tuple<ADataType>,
|
||||
ck_tile::tuple<BDataType>,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
@@ -327,15 +327,15 @@ struct GemmKernel {{
|
||||
}};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{{}})));
|
||||
args.M, args.K, args.stride_As[0], is_row_major(ALayout{{}})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{{}})));
|
||||
args.K, args.N, args.stride_Bs[0], is_row_major(BLayout{{}})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {{
|
||||
|
||||
@@ -90,16 +90,16 @@ class GemmProfiler
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs<> gemm_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{a_m_k_dev_buf.GetDeviceBuffer()},
|
||||
{b_k_n_dev_buf.GetDeviceBuffer()},
|
||||
{}, // ds_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
{gemm_problem.stride_a_},
|
||||
{gemm_problem.stride_b_},
|
||||
{}, // stride_Ds
|
||||
gemm_problem.stride_c_,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user