mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
multi instance generation for CkTileEngine (#2080)
* Add support for multi-instance verification, print detail for each instance, documentation fix
* clang formatted
* Added Readme file
* updated readme
* Addressing review comments
* clang formatted
* Updated ReadMe and GPU reference code
* simplified dispatch kernel code
* indentation
[ROCm/composable_kernel commit: 7cadf187e2]
This commit is contained in:
51
tile_engine/ops/gemm/README.md
Normal file
51
tile_engine/ops/gemm/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# GEMM Matrix Multiplication
|
||||
|
||||
Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`.
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json`
|
||||
|
||||
## Build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# To generate the executable
|
||||
make tile_engine_gemm -j
|
||||
```
|
||||
`tile_engine_gemm` will be located in the `./bin/` directory.
|
||||
|
||||
## tile_engine_gemm inputs
|
||||
```
|
||||
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
-pad_m Pad in m direction - true/false (default:false)
|
||||
-pad_n Pad in n direction - true/false (default:false)
|
||||
-pad_k Pad in k direction - true/false (default:false)
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
Below example will run gemm kernel with default dimensions of matrices, for compv3 pipeline, intrawave scheduler and default epilogue with all possible tile sizes mentioned in Config file.
|
||||
|
||||
```
|
||||
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
|
||||
```
|
||||
@@ -19,7 +19,7 @@
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64]
|
||||
"values": [64, 32]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [2]
|
||||
|
||||
@@ -6,11 +6,16 @@
|
||||
#include "gemm_dispatcher.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
float gemm_kernel_launch(KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
return GemmDispatcher::dispatch(trait, args, s);
|
||||
return GemmDispatcher::dispatch(
|
||||
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -20,11 +25,10 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
void run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
const ALayout a_layout = ALayout{};
|
||||
const BLayout b_layout = BLayout{};
|
||||
// const CLayout c_layout = CLayout{};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
@@ -113,43 +117,47 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
trait.kPadN = arg_parser.get_bool("pad_n");
|
||||
trait.kPadK = arg_parser.get_bool("pad_k");
|
||||
|
||||
float ave_time = gemm_kernel_launch(
|
||||
trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
if(verify)
|
||||
{
|
||||
pass = gemm_verify<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
verify,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_dev_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch);
|
||||
gemm_host_reference<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
return pass;
|
||||
|
||||
gemm_kernel_launch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -159,7 +167,8 @@ int main(int argc, char* argv[])
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
return run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
|
||||
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
@@ -54,24 +57,21 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief trait for GEMM kernel
|
||||
* @param pipeline: pipeline name
|
||||
* @param scheduler: scheduler name
|
||||
* @param epilogue: epilogue name
|
||||
* @param kPadM: padding for M dimension
|
||||
* @param kPadN: padding for N dimension
|
||||
* @param kPadK: padding for K dimension
|
||||
*
|
||||
*/
|
||||
|
||||
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
|
||||
/// specific kernel instance based on the provided settings.
|
||||
struct KernelTraits
|
||||
{
|
||||
/// @brief The name of the pipeline.
|
||||
std::string pipeline;
|
||||
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
|
||||
std::string scheduler;
|
||||
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
|
||||
std::string epilogue;
|
||||
/// @brief Indicates whether padding is applied to the M dimension.
|
||||
bool kPadM;
|
||||
/// @brief Indicates whether padding is applied to the N dimension.
|
||||
bool kPadN;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool kPadK;
|
||||
};
|
||||
|
||||
@@ -184,11 +184,28 @@ void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Function to verify the kernel output with reference implementation on CPU/GPU
|
||||
*
|
||||
*/
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
void compare(ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -196,43 +213,25 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool gemm_verify(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch)
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
bool pass = true;
|
||||
if(verify == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
c_m_n_host_result.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
@@ -241,29 +240,14 @@ bool gemm_verify(int verify,
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
|
||||
c_m_n_host_result.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
@@ -273,30 +257,6 @@ bool gemm_verify(int verify,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -447,6 +447,17 @@ struct GemmKernel {{
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
static std::string get_name() {{
|
||||
return std::string("GemmKernel<Bllktile: ") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + ", " +
|
||||
"WaveMap: " + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + ", " +
|
||||
"WarpTile: " + std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + ", " +
|
||||
"PadidngM: " + "{kPadM}" + ", " +
|
||||
"PaddingN: " + "{kPadN}" + ", " +
|
||||
"PaddingK: " + "{kPadK}" + ", " +
|
||||
"Pipeline: " + "{pipeline}" + ", " +
|
||||
"Epilogue: " + "{epilogue}" + ", " +
|
||||
"Scheduler: " + "{scheduler}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
@@ -476,7 +487,10 @@ struct GemmDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<std::string,
|
||||
std::function<float(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
@@ -499,9 +513,12 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
std::vector<float> results;"""
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
@@ -509,21 +526,46 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
//we can have multiple tiles config for the one kernel_trait
|
||||
return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);"""
|
||||
content += """
|
||||
};\n"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
|
||||
content += """ }
|
||||
|
||||
|
||||
static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
template <typename Kernel>
|
||||
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
std::string description = Kernel::get_name();
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
if(verify)
|
||||
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(gemm_args, s); //Running single instance
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user