mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "config.h"
|
||||
|
||||
struct DeviceMem
|
||||
{
|
||||
@@ -27,4 +28,31 @@ struct KernelTimer
|
||||
std::unique_ptr<KernelTimerImpl> impl;
|
||||
};
|
||||
|
||||
void launch_kernel(const void* func, dim3 grid_dim, dim3 block_dim, void** args, float& time);
|
||||
template <typename... Args, typename F>
|
||||
float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
#if DEVICE_BACKEND_HIP
|
||||
timer.Start();
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, 0, 0, args...);
|
||||
|
||||
timer.End();
|
||||
|
||||
hipGetErrorString(hipGetLastError());
|
||||
#elif DEVICE_BACKEND_CUDA
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args = {&args...};
|
||||
|
||||
timer.Start();
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, 0, 0);
|
||||
|
||||
timer.End();
|
||||
|
||||
checkCudaErrors(error);
|
||||
#endif
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user