[CK_tile] Add rotating buffer feature for universal gemm (#2200)

* Add rotating buffer feature for universal gemm

* adding changes in tile_engine

* Updated code to merge kernel_launch

* removing comments

* Enable rotating buffer changes to flatmm

* Created diff launch_kernel function for rotating buffer

* Simplfied calculation using macros

* merge code with new changes in tile_engine

* clang formatted

* Redefine macros
This commit is contained in:
Khushbu Agarwal
2025-05-27 23:00:58 -07:00
committed by GitHub
parent c52649ad57
commit 99857e10e6
17 changed files with 409 additions and 74 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,6 +11,13 @@
#include <cstddef>
namespace ck_tile {
#define CU_FOR_MI308 80
#define CU_FOR_MI300X 228
#define OPTIMAL_LATENCY_MI308 0.005
#define OPTIMAL_LATENCY_MI300X 0.0015
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
@@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
launch_and_check(s, std::forward<Callables>(callables)...);
@@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
}
auto time_launches = [&](auto timer) {
// warmup
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
@@ -114,4 +123,52 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
}
}
template <typename PreprocessFunc, typename... Callables>
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
PreprocessFunc preprocess,
Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
return 0;
}
auto time_launches = [&](auto timer) {
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset =
(deviceProps.multiProcessorCount >= CU_FOR_MI300X) ? OPTIMAL_LATENCY_MI300X
: (deviceProps.multiProcessorCount == CU_FOR_MI308) ? OPTIMAL_LATENCY_MI308
: OPTIMAL_LATENCY_SAFE_MARGIN;
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
};
if(s.is_gpu_timer_)
{
return time_launches(gpu_timer{});
}
else
{
return time_launches(cpu_timer{});
}
}
} // namespace ck_tile