From fef66ea9617987949929e0c5171f2b91c26b121a Mon Sep 17 00:00:00 2001 From: ltqin Date: Fri, 26 Apr 2024 04:07:14 +0800 Subject: [PATCH] Universal gemm flush cache (#1251) * add flush cache to device op * add flush cache parameter to ckProfiler * change calculate size a and b method * change evaluation time method from AVERAGE to MEDIAN * format code * adjust some code * fix core dumped * remove loop call flush icache in kernel * remove loop(outer) call flush icache --------- Co-authored-by: letaoqin [ROCm/composable_kernel commit: f448d179b7670e8e5d821aa0a49156009ab48a7a] --- include/ck/host_utility/flush_cache.hpp | 229 ++++++++++++++++++ include/ck/stream_config.hpp | 3 + .../impl/device_gemm_xdl_cshuffle_v3.hpp | 50 +++- include/ck/utility/flush_icache.hpp | 30 +++ .../profiler/profile_gemm_universal_impl.hpp | 20 +- profiler/src/profile_gemm_universal.cpp | 14 +- 6 files changed, 331 insertions(+), 15 deletions(-) create mode 100644 include/ck/host_utility/flush_cache.hpp create mode 100644 include/ck/utility/flush_icache.hpp diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp new file mode 100644 index 0000000000..805fb571fb --- /dev/null +++ b/include/ck/host_utility/flush_cache.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/flush_icache.hpp" +namespace ck { +namespace utility { + +template +struct RotatingMemWrapper +{ + using ADataType = decltype(Argument::p_a_grid); + using BDataType = decltype(Argument::p_b_grid); + + RotatingMemWrapper() = delete; + RotatingMemWrapper(Argument& arg_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_) + : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + { + p_a_grids.push_back(arg.p_a_grid); + p_b_grids.push_back(arg.p_b_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_a_grid = reinterpret_cast(p_a_grids[idx]); + arg.p_b_grid = reinterpret_cast(p_b_grids[idx]); + } + } + void Print() + { + std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapper() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_a_grid = reinterpret_cast(p_a_grids[0]); + arg.p_b_grid = reinterpret_cast(p_b_grids[0]); + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); + } + } + } + + private: + Argument& 
arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::vector p_a_grids; + std::vector p_b_grids; +}; + +inline void flush_icache() +{ + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + ck::flush_icache<<>>(); + hip_check_error(hipGetLastError()); +} +// if TimePrePress == false, return time does not include preprocess's time +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args& args) +{ +#if CK_TIME_KERNEL +#define MEDIAN 1 + if(stream_config.time_kernel_) + { +#if DEBUG_LOG + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); +#endif + // warm up + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = stream_config.nrepeat_; + if(nrepeat == 0) + { + return 0.0; + } +#if DEBUG_LOG + printf("Start running %d times...\n", nrepeat); +#endif + +#if MEDIAN + std::set times; +#else + float total_time = 0; +#endif + for(int i = 0; i < nrepeat; ++i) + { + if constexpr(!TimePreprocess) + { + preprocess(); + } + + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + // calculate preprocess time + if constexpr(TimePreprocess) + { + preprocess(); + } + // run real kernel + kernel<<>>(args); + hip_check_error(hipGetLastError()); + // end real kernel + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + 
hip_check_error(hipEventSynchronize(stop)); + float cur_time = 0; + hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); +#if MEDIAN + times.insert(cur_time); +#else + total_time += cur_time; +#endif + +#if DEBUG_LOG + std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + + printf("args.p_a_grid: %p, args.p_b_grid:%p\n", + static_cast(args.p_a_grid), + static_cast(args.p_b_grid)); +#endif + } + +#if MEDIAN + auto mid = times.begin(); + std::advance(mid, (nrepeat - 1) / 2); + if(nrepeat % 2 == 1) + { + return *mid; + } + else + { + auto mid_next = mid; + std::advance(mid_next, 1); + return (*mid + *mid_next) / 2; + } +#else + return total_time / nrepeat; +#endif + } + else + { + preprocess(); + kernel<<>>(args); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(args); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +} // namespace utility +} // namespace ck diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index a5b1407305..37ba250cf5 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -13,4 +13,7 @@ struct StreamConfig int log_level_ = 0; int cold_niters_ = 5; int nrepeat_ = 50; + + bool flush_cache = false; + int rotating_count = 1; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 9d3e97c3e4..57a25526ce 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" namespace ck { namespace tensor_operation { @@ -151,14 +152,49 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) - 
hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } }; constexpr index_t minimum_occupancy = diff --git a/include/ck/utility/flush_icache.hpp b/include/ck/utility/flush_icache.hpp new file mode 100644 index 0000000000..7378ba5c26 --- /dev/null +++ b/include/ck/utility/flush_icache.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +namespace ck { +static __global__ void flush_icache() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index c77541e0e0..362a5dccd1 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -43,7 +43,8 @@ bool profile_gemm_universal_impl(int do_verification, int StrideC, int KBatch, int n_warmup, - int n_iter) + int n_iter, + uint64_t rotating = 0) { bool pass = true; @@ -66,9 +67,16 @@ bool profile_gemm_universal_impl(int do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; switch(init_method) { @@ -200,8 +208,14 @@ bool profile_gemm_universal_impl(int do_verification, std::string op_name = op_ptr->GetTypeString(); - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + 
rotating_count > 1, + rotating_count}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index 940ef09e59..2185ad8495 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -33,7 +33,7 @@ enum struct GemmDataType int profile_gemm_universal(int argc, char* argv[]) { - if(argc != 15 && argc != 17) + if(argc != 15 && argc != 18) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " @@ -51,6 +51,7 @@ int profile_gemm_universal(int argc, char* argv[]) printf("optional:\n"); printf("arg15: number of warm-up cycles (default 1)\n"); printf("arg16: number of iterations (default 10)\n"); + printf("arg17: memory for rotating buffer (default 0, size in MB)\n"); exit(1); } @@ -70,12 +71,14 @@ int profile_gemm_universal(int argc, char* argv[]) const int StrideC = std::stoi(argv[13]); const int KBatch = std::stoi(argv[14]); - int n_warmup = 1; - int n_iter = 10; - if(argc == 17) + int n_warmup = 1; + int n_iter = 10; + uint64_t rotating = 0; + if(argc == 18) { n_warmup = std::stoi(argv[15]); n_iter = std::stoi(argv[16]); + rotating = std::stoull(argv[17]) * 1024 * 1024; } using F32 = float; @@ -124,7 +127,8 @@ int profile_gemm_universal(int argc, char* argv[]) (StrideC < 0) ? DefaultStrideC : StrideC, KBatch, n_warmup, - n_iter); + n_iter, + rotating); return pass ? 0 : 1; };