mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Simple HostArgs struct
This commit is contained in:
@@ -22,7 +22,7 @@ template <typename ADataType,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
|
||||
@@ -219,5 +219,4 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -166,18 +166,18 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
float ave_time =
|
||||
gemm<ADataType,
|
||||
|
||||
@@ -22,7 +22,7 @@ template <typename ADataType,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
|
||||
@@ -54,7 +54,7 @@ using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -62,6 +62,6 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
using multiple_d_gemm_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
|
||||
using multiple_d_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
|
||||
float multiple_d_gemm(const multiple_d_gemm_kargs& kargs, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -16,21 +16,21 @@ template <typename ADataType,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
|
||||
const void* b_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& d_m_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t StrideA,
|
||||
ck_tile::index_t StrideB,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()> StrideDs,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideC,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
d_m_n_dev_buf,
|
||||
ds_m_n_dev_buf.data(),
|
||||
c_m_n_dev_buf,
|
||||
/*kbatch */ 1,
|
||||
M,
|
||||
@@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideDs.data(),
|
||||
StrideC});
|
||||
|
||||
float ave_time = multiple_d_gemm<ADataType,
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
|
||||
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST BatchedGemmHostArgs() = default;
|
||||
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
|
||||
|
||||
@@ -12,13 +12,12 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmHostArgs() = default;
|
||||
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
const void* ds_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
@@ -26,7 +25,7 @@ struct GemmHostArgs
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
const index_t* stride_Ds_,
|
||||
index_t stride_C_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
@@ -45,14 +44,14 @@ struct GemmHostArgs
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
const void* ds_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
const index_t* stride_Ds;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
@@ -126,19 +125,18 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs
|
||||
MakeKernelArgs(const GemmHostArgs<NumDTensor>& hostArgs)
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
{
|
||||
return GemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
static_cast<const void*>(hostArgs.ds_ptr.data()),
|
||||
hostArgs.ds_ptr, // static_cast<const void*>(hostArgs.ds_ptr.data()),
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds.data(),
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
@@ -55,16 +55,15 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
__host__ static auto
|
||||
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
|
||||
__host__ static constexpr auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -75,8 +74,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
|
||||
@@ -82,8 +82,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
@@ -424,7 +423,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args;
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
|
||||
@@ -64,8 +64,7 @@ class TestCkTileMultipleDGemm : public ::testing::Test
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
void invoke_multi_d_gemm(const ck_tile::GemmHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
void invoke_multi_d_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
@@ -291,18 +290,18 @@ class TestCkTileMultipleDGemm : public ::testing::Test
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
|
||||
|
||||
ck_tile::GemmHostArgs<DsDataType::size()> args({a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
/* kBatch */ 1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideC});
|
||||
ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf.data(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
/* kBatch */ 1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs.data(),
|
||||
StrideC});
|
||||
|
||||
invoke_multi_d_gemm<ADataType,
|
||||
BDataType,
|
||||
|
||||
Reference in New Issue
Block a user