[CK_TILE] Introduces a new GEMM API that splits the existing basic GEMM class into multiple specialized classes. (#2520)

* Initial commit of the new API

* apply clang-format

* PreShuffle preparing

* Apply Preshuffle condition to universal_gemm

* Fix: convert size_t to index_t

* Review changes

* Mode 100755 -> 100644

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
Author: Mateusz Ozga
Date: 2025-07-24 20:39:56 +02:00
Committed by: GitHub
Parent: 1e84fdaca7
Commit: b507d889c1
28 changed files with 2094 additions and 1519 deletions
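At a high level, the commit replaces inheritance from the monolithic GemmKernel with composition over a shared UniversalGemmKernel. A minimal sketch of the resulting shape, with the template parameters left abstract, might look like this (illustration only, not the full implementation from the diff below):

// Sketch only: each specialized kernel keeps its own host/kernel arguments and
// delegates the shared machinery to UniversalGemmKernel instead of inheriting from it.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel
{
    // Alias, not a base class: shared functionality is reached through this type.
    using UniversalGemmKernel =
        UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
    {
        return dim3(UniversalGemmKernel::KernelBlockSize);
    }
    // Batched-specific arguments, grid sizing and operator() are shown in the diff below.
};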


@@ -9,35 +9,41 @@
namespace ck_tile {
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
/// @brief The Batched GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating the
/// kernel arguments object. It contains all information required to build proper kernel
/// arguments and launch the kernel on the GPU. This structure defines the GEMM problem
/// configuration by stating all required information such as the M, N, K sizes and the
/// respective strides.
struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<>
{
CK_TILE_HOST BatchedGemmHostArgs() = default;
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: GemmHostArgs(a_ptr_,
b_ptr_,
{},
c_ptr_,
k_batch_,
M_,
N_,
K_,
stride_A_,
stride_B_,
{},
stride_C_),
CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ck_tile::index_t batch_stride_A_,
ck_tile::index_t batch_stride_B_,
ck_tile::index_t batch_stride_C_,
ck_tile::index_t batch_count_)
: UniversalGemmHostArgs<>({a_ptr_},
{b_ptr_},
{/*ds_ptr*/},
c_ptr_,
k_batch_,
M_,
N_,
K_,
{stride_A_},
{stride_B_},
{/*stride_Ds_*/},
stride_C_),
batch_stride_A(batch_stride_A_),
batch_stride_B(batch_stride_B_),
batch_stride_E(batch_stride_C_),
@@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
struct BatchedGemmKernel
{
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
/// @brief Bring in the UniversalGemmKernel to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ADataType = typename Base::ADataType;
using BDataType = typename Base::BDataType;
using CDataType = typename Base::EDataType;
/// @brief Specify the layout configurations for A, B and C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using TilePartitioner = typename Base::TilePartitioner;
using GemmPipeline = typename Base::GemmPipeline;
using EpiloguePipeline = typename Base::EpiloguePipeline;
using ALayout = typename Base::ALayout;
using BLayout = typename Base::BLayout;
using CLayout = typename Base::ELayout;
/// @brief Specify the data type configurations for A, B and C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
struct BatchedGemmKernelArgs : GemmKernelArgs
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
index_t batch_stride_A;
index_t batch_stride_B;
@@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using KernelArgs = BatchedGemmKernelArgs;
__host__ static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
[[nodiscard]] CK_TILE_HOST static auto GetName() -> const std::string
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
CK_TILE_HOST static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
{
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return dim3(UniversalGemmKernel::KernelBlockSize);
}
CK_TILE_HOST static constexpr BatchedGemmKernelArgs
MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
{
return BatchedGemmKernelArgs{{hostArgs.a_ptr,
hostArgs.b_ptr,
{},
return BatchedGemmKernelArgs{{hostArgs.as_ptr,
hostArgs.bs_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
{},
hostArgs.stride_As,
hostArgs.stride_Bs,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch},
hostArgs.batch_stride_A,
@@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
{
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
@@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
splitk_batch_offset.a_k_split_offset;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + batch_offset_A +
splitk_batch_offset.as_k_split_offset[0];
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
splitk_batch_offset.b_k_split_offset;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + batch_offset_B +
splitk_batch_offset.bs_k_split_offset[0];
const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
@@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
UniversalGemmKernel::RunGemm(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
};
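For illustration, here is a hedged host-side sketch of driving the batched kernel through the new argument types. The partitioner/pipeline/epilogue types (MyTilePartitioner, MyGemmPipeline, MyEpilogue) and the device pointers are placeholders, not names from this commit, and the relevant ck_tile headers are assumed to be included:

// Sketch only: hypothetical pipeline configuration and device buffers.
using Kernel = ck_tile::BatchedGemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpilogue>;

ck_tile::BatchedGemmHostArgs host_args(a_dev, b_dev, c_dev,
                                       /*k_batch_=*/1, M, N, K,
                                       stride_A, stride_B, stride_C,
                                       batch_stride_A, batch_stride_B, batch_stride_C,
                                       batch_count);

auto kargs = Kernel::MakeKernelArgs(host_args);
if(Kernel::IsSupportedArgument(kargs))
{
    const dim3 grids  = Kernel::GridSize(M, N, /*KBatch=*/1, batch_count);
    const dim3 blocks = Kernel::BlockSize();
    // Launch with the ck_tile kernel-launch helpers (or a raw HIP launch) using grids/blocks.
}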

File diff suppressed because it is too large


@@ -0,0 +1,185 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/// @brief The MultiD GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating the
/// kernel arguments object. It contains all information required to build proper kernel
/// arguments and launch the kernel on the GPU. This structure defines the GEMM problem
/// configuration by stating all required information such as the M, N, K sizes and the
/// respective strides. NumDTensor describes the number of D tensors.
template <index_t NumDTensor = 1>
struct GemmMultiDHostArgs
{
CK_TILE_HOST GemmMultiDHostArgs() = default;
CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernelMultiD
{
/// @brief Bring in the UniversalGemmKernel to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
/// @brief Specify the layout configurations for A, B, E and D
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
/// @brief Specify the data type configurations for A, B, E and D
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ALayout>::value &&
!is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, BLayout>::value &&
!is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars.");
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, ELayout>::value &&
!is_detected<is_tuple, EDataType>::value,
"ELayout and EDataType must be scalars.");
/// @brief DsLayout and DsDataType are expected to be tuples, not scalars.
static_assert(is_detected<is_tuple, DsLayout>::value &&
is_detected<is_tuple, DsDataType>::value &&
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
"DsLayout and DsDataType must be tuples and must have the same size.");
/// @brief NumATensor and NumBTensor are always 1; the number of D tensors is set by
/// the user.
static constexpr index_t NumATensor = 1;
static constexpr index_t NumBTensor = 1;
static constexpr index_t NumDTensor = DsDataType::size();
CK_TILE_HOST static auto GetName() -> const std::string
{
return UniversalGemmKernel::GetName();
}
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
{
return UniversalGemmKernel::GridSize(M, N, KBatch);
}
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
return UniversalGemmKernel::MaxOccupancyGridSize(s);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return UniversalGemmKernel::BlockSize();
}
CK_TILE_HOST static constexpr auto
MakeKernelArgs(const GemmMultiDHostArgs<NumDTensor>& hostArgs) ->
typename UniversalGemmKernel::KernelArgs
{
/// @brief Universal GEMM requires array objects and corresponding stride information for
/// matrices A, B, and D.
return UniversalGemmKernel::MakeKernelArgs(
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>({hostArgs.a_ptr},
{hostArgs.b_ptr},
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.k_batch,
hostArgs.M,
hostArgs.N,
hostArgs.K,
{hostArgs.stride_A},
{hostArgs.stride_B},
hostArgs.stride_Ds,
hostArgs.stride_E));
}
CK_TILE_HOST static auto
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
{
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
{
UniversalGemmKernel{}.template operator()(kargs);
}
};
} // namespace ck_tile
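A hedged usage sketch for the multi-D path follows. The type names MyTilePartitioner, MyGemmPipeline and MyMultiDEpilogue are placeholders, the epilogue is assumed to expose DsDataType/DsLayout tuples describing a single D tensor, and <array> plus the relevant ck_tile headers are assumed to be included:

// Sketch only: assumes an epilogue configured for exactly one D tensor (NumDTensor == 1).
using Kernel = ck_tile::GemmKernelMultiD<MyTilePartitioner, MyGemmPipeline, MyMultiDEpilogue>;

std::array<const void*, 1> ds_ptrs        = {d0_dev};
std::array<ck_tile::index_t, 1> stride_Ds = {stride_D0};

ck_tile::GemmMultiDHostArgs<1> host_args(a_dev, b_dev, ds_ptrs, e_dev,
                                         /*k_batch_=*/1, M, N, K,
                                         stride_A, stride_B, stride_Ds, stride_E);

auto kargs = Kernel::MakeKernelArgs(host_args);
if(Kernel::IsSupportedArgument(kargs))
{
    const dim3 grids  = Kernel::GridSize(M, N, /*KBatch=*/1);
    const dim3 blocks = Kernel::BlockSize();
    // Launch as in the batched sketch; the device operator() forwards to UniversalGemmKernel.
}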


@@ -16,37 +16,116 @@
namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating the
/// kernel arguments object. It contains all information required to build proper kernel
/// arguments and launch the kernel on the GPU. This structure defines the GEMM problem
/// configuration by stating all required information such as the M, N, K sizes and the
/// respective strides.
struct GroupedGemmHostArgs
{
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
struct GemmTransKernelArg
{
GemmKernelArgs<> group_karg;
UniversalGemmKernelArgs<> group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = delete;
GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {}
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
{
}
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
struct GroupedGemmKernel
{
/// @brief Bring in the UniversalGemmKernel to support execution of all necessary
/// functions.
using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the layout configurations for A, B, C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
@@ -65,8 +144,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
}
@@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -107,8 +185,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
-> std::vector<GemmTransKernelArg>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;
@@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
grid_size += grid_size_grp;
auto karg = GemmKernelArgs<>{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
{},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
stride_a,
stride_b,
{},
stride_e,
gemm_descs[i].k_batch};
auto karg =
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{/*ds_ptr*/},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{/*stride_ds*/},
stride_e,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
Run(kargs.group_karg, block_idx_2d, block_idx_z);
}
CK_TILE_DEVICE void Run(const GemmKernelArgs<>& kargs,
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
@@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
splitk_batch_offset.bs_k_split_offset[0];
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// allocate LDS
@@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
else
{
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
c_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
@@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
* batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
@@ -234,7 +321,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const GemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset);
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
@@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template

File diff suppressed because it is too large
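Finally, a hedged sketch of the grouped path shown above: each GEMM in the group gets one GroupedGemmHostArgs descriptor, which MakeKargs turns into per-group UniversalGemmKernelArgs wrapped in GemmTransKernelArg. Type names and buffers are placeholders; the actual workspace upload and launch follow the existing ck_tile grouped-GEMM flow:

// Sketch only: two hypothetical GEMM problems in a single grouped launch.
using Kernel = ck_tile::GroupedGemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpilogue>;

std::vector<ck_tile::GroupedGemmHostArgs> gemm_descs;
gemm_descs.emplace_back(a0_dev, b0_dev, e0_dev, /*k_batch_=*/1,
                        M0, N0, K0, stride_A0, stride_B0, stride_E0);
gemm_descs.emplace_back(a1_dev, b1_dev, e1_dev, /*k_batch_=*/1,
                        M1, N1, K1, stride_A1, stride_B1, stride_E1);

// Size of the device workspace that receives the per-group kernel arguments.
const std::size_t workspace_bytes = Kernel::GetWorkSpaceSize(gemm_descs);

const std::vector<ck_tile::GemmTransKernelArg> kargs = Kernel::MakeKargs(gemm_descs);
const dim3 grids = Kernel::GridSize(gemm_descs);
// Copy kargs into the workspace buffer and launch with the ck_tile kernel-launch helpers.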