[CK_TILE] Introduces a new GEMM API that splits the existing basic GEMM class into multiple specialized classes. (#2520)

* Init commit new API

* apply clang-format

* PreShuffle preapring

* Apply Preshuffle condition to universal_gemm

* Fix: convert size_t to index_t

* Review changes

* Mode 100755 -> 100644

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Mateusz Ozga
2025-07-24 20:39:56 +02:00
committed by GitHub
parent 1e84fdaca7
commit b507d889c1
28 changed files with 2094 additions and 1519 deletions

View File

@@ -233,7 +233,7 @@ struct GemmKernel {{
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -335,7 +335,7 @@ struct GemmKernel {{
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {{
@@ -680,7 +680,7 @@ struct GemmDispatcher {
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
@@ -705,7 +705,7 @@ struct GemmDispatcher {
warp_tile_n,
warp_tile_k,
) = tile[j]
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = (
@@ -746,7 +746,7 @@ struct GemmDispatcher {
content += """ }
template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);

View File

@@ -22,7 +22,7 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
@@ -89,10 +89,9 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs<> gemm_args = {
ck_tile::GemmHostArgs gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{}, // ds_ptr
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
@@ -100,7 +99,6 @@ class GemmProfiler
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
{}, // stride_Ds
gemm_problem.stride_c_,
};