diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index baa44f1ad6..918fb28ea9 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include @@ -43,8 +42,8 @@ struct RotatingMemWrapperMultiD { { void* pADeviceBuf; - HIP_CHECK_ERROR(hipMalloc(static_cast(&pADeviceBuf), size_a_)); - HIP_CHECK_ERROR(hipMemcpy(static_cast(pADeviceBuf), + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), const_cast(p_a_grids[0]), size_a_, hipMemcpyDeviceToDevice)); @@ -53,8 +52,8 @@ struct RotatingMemWrapperMultiD { void* pBDeviceBuf; - HIP_CHECK_ERROR(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); - HIP_CHECK_ERROR(hipMemcpy(static_cast(pBDeviceBuf), + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), const_cast(p_b_grids[0]), size_b_, hipMemcpyDeviceToDevice)); @@ -66,8 +65,8 @@ struct RotatingMemWrapperMultiD DsGridPointer ds_buffer; static_for<0, NumDs, 1>{}([&](auto j) { void* pDDeviceBuf; - HIP_CHECK_ERROR(hipMalloc(static_cast(&pDDeviceBuf), size_ds_[j])); - HIP_CHECK_ERROR(hipMemcpy(static_cast(pDDeviceBuf), + hip_check_error(hipMalloc(static_cast(&pDDeviceBuf), size_ds_[j])); + hip_check_error(hipMemcpy(static_cast(pDDeviceBuf), static_cast(p_ds_grids[0][j]), size_ds_[j], hipMemcpyDeviceToDevice)); @@ -94,10 +93,8 @@ struct RotatingMemWrapperMultiD } void Print() { - std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b; - static_for<0, NumDs, 1>{}( - [&](auto j) { std::cout << ", size_d" << j.value << ": " << size_ds[j]; }); - std::cout << ", rotating_count: " << rotating_count << "}" << std::endl; + std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; } ~RotatingMemWrapperMultiD() { @@ -111,35 +108,13 @@ struct RotatingMemWrapperMultiD // free device mem for(size_t i = 1; i < rotating_count; i++) { - try - { - HIP_CHECK_ERROR(hipFree(const_cast(p_a_grids[i]))); - } - catch(std::runtime_error& re) - { - std::cerr << re.what() << std::endl; - } - - try - { - HIP_CHECK_ERROR(hipFree(const_cast(p_b_grids[i]))); - } - catch(std::runtime_error& re) - { - std::cerr << re.what() << std::endl; - } + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); static_for<0, NumDs, 1>{}([&](auto j) { using DDataType = remove_cvref_t>; - try - { - HIP_CHECK_ERROR( - hipFree(static_cast(const_cast(p_ds_grids[i][j])))); - } - catch(std::runtime_error& re) - { - std::cerr << re.what() << std::endl; - } + hip_check_error( + hipFree(static_cast(const_cast(p_ds_grids[i][j])))); }); } } @@ -176,8 +151,8 @@ struct RotatingMemWrapper { { void* pADeviceBuf; - HIP_CHECK_ERROR(hipMalloc(static_cast(&pADeviceBuf), size_a_)); - HIP_CHECK_ERROR(hipMemcpy(static_cast(pADeviceBuf), + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), const_cast(p_a_grids[0]), size_a_, hipMemcpyDeviceToDevice)); @@ -186,8 +161,8 @@ struct RotatingMemWrapper { void* pBDeviceBuf; - HIP_CHECK_ERROR(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); - HIP_CHECK_ERROR(hipMemcpy(static_cast(pBDeviceBuf), + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), const_cast(p_b_grids[0]), size_b_, hipMemcpyDeviceToDevice)); @@ -221,23 +196,8 @@ struct RotatingMemWrapper // free device mem for(size_t i = 1; i < rotating_count; i++) { - try - { - HIP_CHECK_ERROR(hipFree(const_cast(p_a_grids[i]))); - } - catch(std::runtime_error& re) - { - std::cerr << re.what() << std::endl; - } - - try - { - HIP_CHECK_ERROR(hipFree(const_cast(p_b_grids[i]))); - } - catch(std::runtime_error& re) - { - std::cerr << re.what() << std::endl; - } + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); } } } @@ -255,20 +215,25 @@ struct RotatingMemWrapper inline void flush_icache() { hipDeviceProp_t deviceProps; - HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0)); + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; ck::flush_icache<<>>(); - HIP_CHECK_ERROR(hipGetLastError()); + hip_check_error(hipGetLastError()); } // if TimePrePress == false, return time does not include preprocess's time -template +template float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, + GemmArgs& gemm_args, Args... args) { #if CK_TIME_KERNEL @@ -291,8 +256,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { - kernel<<>>(args...); - HIP_CHECK_ERROR(hipGetLastError()); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); } const int nrepeat = stream_config.nrepeat_; @@ -312,36 +277,54 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #endif hipEvent_t start, stop; - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { - preprocess(); + if constexpr(!TimePreprocess) + { + preprocess(); + } + // hipEvent_t start, stop; + + // hip_check_error(hipEventCreate(&start)); + // hip_check_error(hipEventCreate(&stop)); + + // hip_check_error(hipDeviceSynchronize()); + // hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + // calculate preprocess time + if constexpr(TimePreprocess) + { + preprocess(); + } // run real kernel - hipExtLaunchKernelGGL(kernel, - grid_dim, - block_dim, - lds_byte, - stream_config.stream_id_, - start, - stop, - 0, - args...); - HIP_CHECK_ERROR(hipGetLastError()); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); // end real kernel - HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); + // hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + // hip_check_error(hipEventSynchronize(stop)); + // float cur_time = 0; + // hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); + // #if MEDIAN + // times.insert(cur_time); + // #else + // total_time += cur_time; + // #endif - float cur_time = 0; - HIP_CHECK_ERROR(hipEventElapsedTime(&cur_time, start, stop)); -#if MEDIAN - times.insert(cur_time); -#else - total_time += cur_time; -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); + } } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventSynchronize(stop)); @@ -367,20 +350,24 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, return (*mid + *mid_next) / 2; } #else - return total_time / nrepeat; + // return total_time / nrepeat; + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01; + return (total_time - preprocess_offset * nrepeat) / nrepeat; #endif } else { preprocess(); - kernel<<>>(args...); - HIP_CHECK_ERROR(hipGetLastError()); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); return 0; } #else - kernel<<>>(args...); - HIP_CHECK_ERROR(hipGetLastError()); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); return 0; #endif