Files
composable_kernel/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Mateusz Ozga 3c93d3c444 CK-Tile Grouped GEMM refactor and post PR fixes (#1756)
* Grouped gemm simple code refactor

* Offset invoker

* Invoke generic Run, and replace name of parrtitioner variable

* Tests fix type

* Removed namespaces

* Add template param to avoid implicit cast

* Remove generic function

* Constant value

* underline enum to int16_t

* Generalize partitioner function

* Remove whitespaces

* Rename function

* Using support

* Clang-format

* Clang-format

* Fn-partitioner description fn

* Typo

* Typo 2

* Better description

* Better description

* Refactor after review

* Use ctr instead of set fn

* Inovke ctr and typo

* Comments

* Remove unnecessary comment

* Review, remove modulo
2025-01-21 21:06:10 +01:00

60 lines
2.2 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using CDataType = ck_tile::half_t;
using AccDataType = float;
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);