mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919)
* switch to universal gemm for batched and grouped gemms * added reviewer comments * fixed grouped gemm tests
This commit is contained in:
@@ -46,7 +46,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
{
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using GemmKernelArgs = typename Base::GemmKernelArgs;
|
||||
using GemmKernelArgs = typename ck_tile::GemmKernelArgs;
|
||||
|
||||
using ADataType = typename Base::ADataType;
|
||||
using BDataType = typename Base::BDataType;
|
||||
@@ -65,7 +65,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
|
||||
@@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
@@ -90,20 +104,6 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
{
|
||||
return GemmKernelArgs{hostArgs.a_ptr,
|
||||
|
||||
@@ -11,24 +11,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
|
||||
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_)
|
||||
: GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
|
||||
GemmKernelArgs group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr index_t KBatch = 1;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -47,36 +40,22 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
using GemmKernelArgs = typename Base::GemmKernelArgs;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
GemmKernelArgs group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
@@ -84,7 +63,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
|
||||
__host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -95,7 +74,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
|
||||
Reference in New Issue
Block a user