Simple HostArgs struct

This commit is contained in:
Mateusz Ozga
2025-03-22 18:39:35 +00:00
parent f6a2cfb1ef
commit 8ce06348b5
13 changed files with 49 additions and 56 deletions

View File

@@ -22,7 +22,7 @@ template <typename ADataType,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.

View File

@@ -219,5 +219,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float gemm_calc(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
const ck_tile::stream_config& s);
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -166,18 +166,18 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{},
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{},
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
float ave_time =
gemm<ADataType,

View File

@@ -22,7 +22,7 @@ template <typename ADataType,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<

View File

@@ -54,7 +54,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
auto create_args(int argc, char* argv[])
{

View File

@@ -62,6 +62,6 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
using multiple_d_gemm_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
using multiple_d_gemm_kargs = ck_tile::GemmHostArgs;
float multiple_d_gemm(const multiple_d_gemm_kargs& kargs, const ck_tile::stream_config& s);

View File

@@ -16,21 +16,21 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
const void* b_k_n_dev_buf,
const std::array<const void*, DsDataType::size()>& d_m_n_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
void* c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t StrideA,
ck_tile::index_t StrideB,
const std::array<ck_tile::index_t, DsDataType::size()> StrideDs,
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
ck_tile::index_t StrideC,
int n_warmup,
int n_repeat)
{
multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf,
b_k_n_dev_buf,
d_m_n_dev_buf,
ds_m_n_dev_buf.data(),
c_m_n_dev_buf,
/*kbatch */ 1,
M,
@@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
K,
StrideA,
StrideB,
StrideDs,
StrideDs.data(),
StrideC});
float ave_time = multiple_d_gemm<ADataType,

View File

@@ -9,7 +9,7 @@
namespace ck_tile {
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,

View File

@@ -12,13 +12,12 @@
namespace ck_tile {
template <index_t NumDTensor = 0>
struct GemmHostArgs
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
const void* ds_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
@@ -26,7 +25,7 @@ struct GemmHostArgs
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
const index_t* stride_Ds_,
index_t stride_C_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
@@ -45,14 +44,14 @@ struct GemmHostArgs
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
const void* ds_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
const index_t* stride_Ds;
index_t stride_C;
index_t k_batch;
};
@@ -126,19 +125,18 @@ struct GemmKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr GemmKernelArgs
MakeKernelArgs(const GemmHostArgs<NumDTensor>& hostArgs)
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
{
return GemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
static_cast<const void*>(hostArgs.ds_ptr.data()),
hostArgs.ds_ptr, // static_cast<const void*>(hostArgs.ds_ptr.data()),
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds.data(),
hostArgs.stride_Ds,
hostArgs.stride_C,
hostArgs.k_batch};
}

View File

@@ -55,16 +55,15 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format on
}
__host__ static auto
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
-> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
}
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
__host__ static constexpr auto
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -75,8 +74,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
CK_TILE_HOST static auto MakeKargs(const std::vector<GemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;

View File

@@ -82,8 +82,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK>
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
const ck_tile::stream_config& s)
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
@@ -424,7 +423,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args;
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();

View File

@@ -47,7 +47,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 8;
};
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);

View File

@@ -64,8 +64,7 @@ class TestCkTileMultipleDGemm : public ::testing::Test
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
void invoke_multi_d_gemm(const ck_tile::GemmHostArgs<DsDataType::size()>& args,
const ck_tile::stream_config& s)
void invoke_multi_d_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
@@ -291,18 +290,18 @@ class TestCkTileMultipleDGemm : public ::testing::Test
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
ck_tile::GemmHostArgs<DsDataType::size()> args({a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
c_m_n_dev_buf.GetDeviceBuffer(),
/* kBatch */ 1,
M,
N,
K,
StrideA,
StrideB,
stridesDs,
StrideC});
ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf.data(),
c_m_n_dev_buf.GetDeviceBuffer(),
/* kBatch */ 1,
M,
N,
K,
StrideA,
StrideB,
stridesDs.data(),
StrideC});
invoke_multi_d_gemm<ADataType,
BDataType,