mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix conflict
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core[api_reference]==1.18.4
|
||||
rocm-docs-core[api_reference]==1.19.0
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -237,7 +237,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core[api-reference]==1.18.4
|
||||
rocm-docs-core[api-reference]==1.19.0
|
||||
# via -r requirements.in
|
||||
rpds-py==0.24.0
|
||||
# via
|
||||
|
||||
@@ -214,4 +214,15 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,4 +220,11 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -178,7 +178,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
@@ -74,64 +75,102 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -243,8 +282,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -345,7 +382,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
run_gemm_example(argc, argv);
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
try
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec == "fp32" && index_prec == "int32")
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -320,4 +320,15 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_batched_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,4 +319,15 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = false;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_grouped_gemm_example<Persistent>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -115,9 +116,47 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time{0};
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
@@ -132,8 +171,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -177,4 +214,15 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_flatmm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,4 +133,11 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -122,7 +122,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
|
||||
@@ -37,3 +37,5 @@
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
#include "ck_tile/host/flush_icache.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
|
||||
30
include/ck_tile/host/flush_icache.hpp
Normal file
30
include/ck_tile/host/flush_icache.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
static __global__ void flush_cache()
|
||||
{
|
||||
asm __volatile__("s_icache_inv \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t" ::
|
||||
:);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define CU_FOR_MI308 80
|
||||
#define CU_FOR_MI300X 228
|
||||
#define OPTIMAL_LATENCY_MI308 0.005
|
||||
#define OPTIMAL_LATENCY_MI300X 0.0015
|
||||
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
@@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -114,4 +123,52 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PreprocessFunc, typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
|
||||
PreprocessFunc preprocess,
|
||||
Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
|
||||
float preprocess_offset =
|
||||
(deviceProps.multiProcessorCount >= CU_FOR_MI300X) ? OPTIMAL_LATENCY_MI300X
|
||||
: (deviceProps.multiProcessorCount == CU_FOR_MI308) ? OPTIMAL_LATENCY_MI308
|
||||
: OPTIMAL_LATENCY_SAFE_MARGIN;
|
||||
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
|
||||
};
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(a_ptr);
|
||||
p_b_grids.push_back(b_ptr);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
a_ptr = p_a_grids[idx];
|
||||
b_ptr = p_b_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapper() noexcept
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
a_ptr = p_a_grids[0];
|
||||
b_ptr = p_b_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
};
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,5 +30,7 @@ struct stream_config
|
||||
int cold_niters_ = 3;
|
||||
int nrepeat_ = 10;
|
||||
bool is_gpu_timer_ = true; // keep compatible
|
||||
bool flush_cache_ = false;
|
||||
int rotating_count_ = 1;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -54,7 +54,7 @@ target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
|
||||
|
||||
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
|
||||
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp profile_cache_db.hpp)
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp benchmark_perf_db.hpp)
|
||||
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api ${SQLite3_LIBRARIES})
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
|
||||
|
||||
@@ -36,14 +36,16 @@ rm -rf tile_engine/ && ninja benchmark_gemm # rebuild
|
||||
-stride_b The stride value for tensor B. Default is 0.
|
||||
-stride_c The stride value for tensor C Default is 0.
|
||||
-split_k The split value for k dimension. Default is 1.
|
||||
-enable_profile_cache Whether flush profile cache or not when benchmark kernel. Possible values are or false. Default is true.
|
||||
-flush_profile_cache Whether flush profile cache or not when benchmark kernel. Possible values are true or false. Default is false.
|
||||
-enable_perf_db Whether enable performance database or not when benchmark kernel. Possible values are true or false. Default is true.
|
||||
-clear_perf_db Whether clear performance database or not when benchmark kernel. Possible values are true or false. Default is false.
|
||||
-v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU.
|
||||
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
|
||||
-warmup The number of iterations before benchmark the kernel. Default is 50.
|
||||
-repeat The number of iterations to benchmark the kernel. Default is 100.
|
||||
-timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true.
|
||||
-timer Whether the timer is gpu timer or not. Possible values are true or false. Default is true.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-flush_cache Whether flush cache or not in between different runs. Possible values are true or false. Default is false.
|
||||
-rotating_count The number of iterations to rotate the cache. Default is 5.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-csv_filename The filename of benchmark result. Default is gemm_kernel.
|
||||
-structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false.
|
||||
|
||||
@@ -26,17 +26,17 @@ void benchmark_gemm(const ck_tile::ArgParser& arg_parser)
|
||||
CLayout::name,
|
||||
arg_parser.get_bool("structured_sparsity")};
|
||||
|
||||
Setting setting{
|
||||
arg_parser.get_bool("enable_profile_cache"),
|
||||
arg_parser.get_bool("flush_profile_cache"),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
};
|
||||
Setting setting{arg_parser.get_bool("enable_perf_db"),
|
||||
arg_parser.get_bool("clear_perf_db"),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count")};
|
||||
|
||||
auto& profiler = GemmProfiler::instance(setting);
|
||||
|
||||
|
||||
@@ -126,8 +126,8 @@ struct KernelInstance
|
||||
|
||||
struct Setting
|
||||
{
|
||||
bool enable_profile_cache_;
|
||||
bool flush_profile_cache_;
|
||||
bool enable_perf_db_;
|
||||
bool clear_perf_db_;
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
@@ -135,6 +135,8 @@ struct Setting
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
|
||||
@@ -53,10 +53,10 @@ class StmtWrapper
|
||||
std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
|
||||
};
|
||||
|
||||
class ProfileCacheDB
|
||||
class BenchmarkPerfDB
|
||||
{
|
||||
public:
|
||||
explicit ProfileCacheDB(const std::filesystem::path& path)
|
||||
explicit BenchmarkPerfDB(const std::filesystem::path& path)
|
||||
: db_ptr_(
|
||||
[path] {
|
||||
sqlite3* raw_db_ptr = nullptr;
|
||||
20
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file → Normal file
20
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file → Normal file
@@ -73,14 +73,14 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("enable_profile_cache",
|
||||
.insert("enable_perf_db",
|
||||
"true",
|
||||
"Whether use profile cache or not when benchmark kernel, Possible values are true "
|
||||
"or false. Default is true.")
|
||||
.insert("flush_profile_cache",
|
||||
"Whether enable performance database or not when benchmark kernel. Possible "
|
||||
"values are true or false. Default is true.")
|
||||
.insert("clear_perf_db",
|
||||
"false",
|
||||
"Whether flush profile cache or not when benchmark kernel. Possible values are "
|
||||
"true or false. Default is false.")
|
||||
"Whether clear performance database or not when benchmark kernel. Possible values "
|
||||
"are true or false. Default is false.")
|
||||
.insert("verify",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
@@ -95,12 +95,18 @@ inline auto create_args(int argc, char* argv[])
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Whether the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"false",
|
||||
"Whether flush cache or not in between different runs. Possible values are true or "
|
||||
"false. Default is false.")
|
||||
.insert(
|
||||
"rotating_count", "5", "The number of iterations to rotate the cache. Default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
|
||||
@@ -273,9 +273,52 @@ struct GemmKernel {{
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(stream,
|
||||
if(stream.flush_cache_)
|
||||
{{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
auto is_row_major = [](auto layout_) {{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{{}};
|
||||
}};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{{}})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{{}})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {{
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
|
||||
}};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
stream,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
else{{
|
||||
ave_time = ck_tile::launch_kernel(stream,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
return ave_time;
|
||||
|
||||
}};
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
#include "profile_cache_db.hpp"
|
||||
#include "benchmark_perf_db.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
@@ -24,7 +24,7 @@ class GemmProfiler
|
||||
|
||||
bool if_should_profile(const GemmProblem& gemm_problem)
|
||||
{
|
||||
if(setting_.enable_profile_cache_)
|
||||
if(setting_.enable_perf_db_)
|
||||
{
|
||||
if(!cache_db_->check_if_record_problem(
|
||||
get_rocm_version(), ck_tile::get_device_name(), gemm_problem))
|
||||
@@ -42,7 +42,7 @@ class GemmProfiler
|
||||
kernel_instance.perf_result_.tflops_ = perf_result.tflops_;
|
||||
kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_;
|
||||
std::cout << "Skip this instance for " << kernel_instance
|
||||
<< ", Because it has already been recorded in the cache database"
|
||||
<< ", Because it has already been recorded in the cache database. "
|
||||
<< std::endl;
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
return false;
|
||||
@@ -165,7 +165,9 @@ class GemmProfiler
|
||||
setting_.log_,
|
||||
setting_.n_warmup_,
|
||||
setting_.n_repeat_,
|
||||
setting_.is_gpu_timer_});
|
||||
setting_.is_gpu_timer_,
|
||||
setting_.flush_cache_,
|
||||
setting_.rotating_count_});
|
||||
process_result(gemm_problem,
|
||||
c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
@@ -173,7 +175,7 @@ class GemmProfiler
|
||||
kernel_run_result);
|
||||
}
|
||||
|
||||
if(setting_.enable_profile_cache_)
|
||||
if(setting_.enable_perf_db_)
|
||||
{
|
||||
cache_db_->insert_cache(
|
||||
get_rocm_version(), ck_tile::get_device_name(), kernel_instances_);
|
||||
@@ -299,7 +301,7 @@ class GemmProfiler
|
||||
|
||||
void initialize_profile_cache()
|
||||
{
|
||||
if(setting_.enable_profile_cache_)
|
||||
if(setting_.enable_perf_db_)
|
||||
{
|
||||
std::filesystem::path cache_db_prefix_path =
|
||||
std::filesystem::current_path() / ".tile_engine";
|
||||
@@ -347,7 +349,7 @@ class GemmProfiler
|
||||
|
||||
void handle_flush_cache(const std::filesystem::path& cache_db_path) const
|
||||
{
|
||||
if(setting_.flush_profile_cache_ && std::filesystem::exists(cache_db_path))
|
||||
if(setting_.clear_perf_db_ && std::filesystem::exists(cache_db_path))
|
||||
{
|
||||
std::error_code ec;
|
||||
if(std::filesystem::remove(cache_db_path, ec))
|
||||
@@ -366,7 +368,7 @@ class GemmProfiler
|
||||
{
|
||||
try
|
||||
{
|
||||
cache_db_ = std::make_unique<ProfileCacheDB>(path);
|
||||
cache_db_ = std::make_unique<BenchmarkPerfDB>(path);
|
||||
std::cout << "Loaded profile cache from " << path << std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
@@ -376,6 +378,6 @@ class GemmProfiler
|
||||
}
|
||||
|
||||
Setting setting_;
|
||||
std::unique_ptr<ProfileCacheDB> cache_db_;
|
||||
std::unique_ptr<BenchmarkPerfDB> cache_db_;
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user