mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
using develop branch timer
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_ext.h>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
@@ -43,8 +42,8 @@ struct RotatingMemWrapperMultiD
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
@@ -53,8 +52,8 @@ struct RotatingMemWrapperMultiD
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
@@ -66,8 +65,8 @@ struct RotatingMemWrapperMultiD
|
||||
DsGridPointer ds_buffer;
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
void* pDDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pDDeviceBuf),
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
|
||||
static_cast<const void*>(p_ds_grids[0][j]),
|
||||
size_ds_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
@@ -94,10 +93,8 @@ struct RotatingMemWrapperMultiD
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b;
|
||||
static_for<0, NumDs, 1>{}(
|
||||
[&](auto j) { std::cout << ", size_d" << j.value << ": " << size_ds[j]; });
|
||||
std::cout << ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapperMultiD()
|
||||
{
|
||||
@@ -111,35 +108,13 @@ struct RotatingMemWrapperMultiD
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
}
|
||||
catch(std::runtime_error& re)
|
||||
{
|
||||
std::cerr << re.what() << std::endl;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
catch(std::runtime_error& re)
|
||||
{
|
||||
std::cerr << re.what() << std::endl;
|
||||
}
|
||||
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
try
|
||||
{
|
||||
HIP_CHECK_ERROR(
|
||||
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
|
||||
}
|
||||
catch(std::runtime_error& re)
|
||||
{
|
||||
std::cerr << re.what() << std::endl;
|
||||
}
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -176,8 +151,8 @@ struct RotatingMemWrapper
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
@@ -186,8 +161,8 @@ struct RotatingMemWrapper
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
@@ -221,23 +196,8 @@ struct RotatingMemWrapper
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
try
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
}
|
||||
catch(std::runtime_error& re)
|
||||
{
|
||||
std::cerr << re.what() << std::endl;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
catch(std::runtime_error& re)
|
||||
{
|
||||
std::cerr << re.what() << std::endl;
|
||||
}
|
||||
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -255,20 +215,25 @@ struct RotatingMemWrapper
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
// if TimePrePress == false, return time does not include preprocess's time
|
||||
template <bool TimePreprocess, typename... Args, typename F, typename PreProcessFunc>
|
||||
template <bool TimePreprocess,
|
||||
typename GemmArgs,
|
||||
typename... Args,
|
||||
typename F,
|
||||
typename PreProcessFunc>
|
||||
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
GemmArgs& gemm_args,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TIME_KERNEL
|
||||
@@ -291,8 +256,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
// warm up
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
@@ -312,36 +277,54 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#endif
|
||||
hipEvent_t start, stop;
|
||||
|
||||
HIP_CHECK_ERROR(hipEventCreate(&start));
|
||||
HIP_CHECK_ERROR(hipEventCreate(&stop));
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
preprocess();
|
||||
if constexpr(!TimePreprocess)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
|
||||
// hipEvent_t start, stop;
|
||||
|
||||
// hip_check_error(hipEventCreate(&start));
|
||||
// hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
// hip_check_error(hipDeviceSynchronize());
|
||||
// hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
// calculate preprocess time
|
||||
if constexpr(TimePreprocess)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
// run real kernel
|
||||
hipExtLaunchKernelGGL(kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
lds_byte,
|
||||
stream_config.stream_id_,
|
||||
start,
|
||||
stop,
|
||||
0,
|
||||
args...);
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
// end real kernel
|
||||
|
||||
HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_));
|
||||
HIP_CHECK_ERROR(hipEventSynchronize(stop));
|
||||
// hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
// hip_check_error(hipEventSynchronize(stop));
|
||||
// float cur_time = 0;
|
||||
// hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
|
||||
// #if MEDIAN
|
||||
// times.insert(cur_time);
|
||||
// #else
|
||||
// total_time += cur_time;
|
||||
// #endif
|
||||
|
||||
float cur_time = 0;
|
||||
HIP_CHECK_ERROR(hipEventElapsedTime(&cur_time, start, stop));
|
||||
#if MEDIAN
|
||||
times.insert(cur_time);
|
||||
#else
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
|
||||
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
|
||||
static_cast<const void*>(gemm_args.p_a_grid),
|
||||
static_cast<const void*>(gemm_args.p_b_grid));
|
||||
}
|
||||
}
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
@@ -367,20 +350,24 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
return (*mid + *mid_next) / 2;
|
||||
}
|
||||
#else
|
||||
return total_time / nrepeat;
|
||||
// return total_time / nrepeat;
|
||||
hipDeviceProp_t deviceProps;
|
||||
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
|
||||
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
|
||||
return (total_time - preprocess_offset * nrepeat) / nrepeat;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user