Files
composable_kernel/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Mateusz Ozga 4cb3d7d7ea Ck tile grouped GEMM example (#1713)
* Ck-tile, impl. grouped gemm

* Workspace is allocated by user, and is passed to the function

* Prepare test to new api design

* Unify GemTransKernelArgs, removing N0 param

* Add 1 to dim3 in partitioner

* Typo: gem - > gemm

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
2024-12-04 21:40:01 +01:00

54 lines
1.8 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
#include <string>
#include <tuple>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
// Primary template: intentionally left undefined so that only the explicitly
// specialized element types below can be used.
template <typename DataType>
struct GemmBasicTypeConfig;

// Type bundle selected by ck_tile::half_t: the A, B and C element types are all
// ck_tile::half_t, while AccDataType is float.
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
    // Element types of the A, B and C tensors.
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using CDataType = ck_tile::half_t;

    // NOTE(review): presumably the accumulator type used inside the GEMM - confirm.
    using AccDataType = float;
};
// The example is fixed to the ck_tile::half_t type configuration.
using Types = GemmBasicTypeConfig<ck_tile::half_t>;

// Re-export the member aliases of Types at file scope for convenient access.
using ADataType   = Types::ADataType;
using BDataType   = Types::BDataType;
using CDataType   = Types::CDataType;
using AccDataType = Types::AccDataType;

// Short name for the host-side argument type used by the declarations below.
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
// Registers the command-line options of this example (each with a default
// value and a help string) on a ck_tile::ArgParser and parses argv with it.
//
// Marked `inline` because this is a function definition inside a header:
// without it, including the header from more than one translation unit
// produces multiple-definition link errors (ODR violation).
//
// Parameters:
//   argc, argv - forwarded unchanged to ArgParser::parse.
//
// Returns:
//   std::tuple of (value returned by ArgParser::parse, copy of the parser).
//   NOTE(review): the bool presumably signals whether parsing succeeded -
//   confirm against ck_tile::ArgParser.
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "R", "B tensor data layout - Row by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("validate", "1", "0. No validation, 1. Validation on CPU")
        .insert("warmup", "10", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("group_count", "16", "group count");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Declaration only; the definition is not in this header.
// Takes the per-group host arguments and returns a size for them.
// NOTE(review): presumably the number of bytes the caller must allocate for the
// workspace passed to grouped_gemm_calc - confirm against the definition.
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
// Declaration only; the definition is not in this header.
//   gemm_descs   - host arguments, one entry per GEMM problem in the group.
//   s            - ck_tile stream configuration handed to the implementation.
//   p_workspace_ - NOTE(review): presumably a caller-allocated buffer of at least
//                  GetWorkspaceSize(gemm_descs) - confirm against the definition.
// NOTE(review): the float return value is presumably the measured kernel time -
// confirm units and meaning against the definition.
float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);