fix formating

This commit is contained in:
Aleksander Dudek
2026-02-10 18:45:06 +00:00
parent 2c2125f73e
commit 9bfcce5566
7 changed files with 245 additions and 234 deletions

View File

@@ -55,25 +55,25 @@ int benchmark_pooling_single(int argc, char* argv[])
return -1;
// Parse problem dimensions
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t H = arg_parser.get_int("h");
ck_tile::index_t W = arg_parser.get_int("w");
ck_tile::index_t C = arg_parser.get_int("c");
ck_tile::index_t Y = arg_parser.get_int("wy");
ck_tile::index_t X = arg_parser.get_int("wx");
ck_tile::index_t Sy = arg_parser.get_int("sy");
ck_tile::index_t Sx = arg_parser.get_int("sx");
ck_tile::index_t Dy = arg_parser.get_int("dy");
ck_tile::index_t Dx = arg_parser.get_int("dx");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t H = arg_parser.get_int("h");
ck_tile::index_t W = arg_parser.get_int("w");
ck_tile::index_t C = arg_parser.get_int("c");
ck_tile::index_t Y = arg_parser.get_int("wy");
ck_tile::index_t X = arg_parser.get_int("wx");
ck_tile::index_t Sy = arg_parser.get_int("sy");
ck_tile::index_t Sx = arg_parser.get_int("sx");
ck_tile::index_t Dy = arg_parser.get_int("dy");
ck_tile::index_t Dx = arg_parser.get_int("dx");
ck_tile::index_t LeftPy = arg_parser.get_int("phy");
ck_tile::index_t RightPy = arg_parser.get_int("phyr");
ck_tile::index_t LeftPx = arg_parser.get_int("pwx");
ck_tile::index_t RightPx = arg_parser.get_int("pwxr");
bool verify = arg_parser.get_int("verify") != 0;
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int log_level = arg_parser.get_int("log");
bool verify = arg_parser.get_int("verify") != 0;
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int log_level = arg_parser.get_int("log");
// Calculate output dimensions
ck_tile::index_t Ys = (Y - 1) * Dy + 1;
@@ -106,30 +106,29 @@ int benchmark_pooling_single(int argc, char* argv[])
d_out_index.SetZero();
// Build host args
const auto input_shape = ck_tile::make_tuple(N, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, 1);
const auto input_shape = ck_tile::make_tuple(N, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, 1);
const auto window_lengths = ck_tile::make_tuple(Y, X);
const auto window_strides = ck_tile::make_tuple(Sy, Sx);
const auto window_dilations = ck_tile::make_tuple(Dy, Dx);
const auto input_left_pads = ck_tile::make_tuple(LeftPy, LeftPx);
const auto input_right_pads = ck_tile::make_tuple(RightPy, RightPx);
auto host_args =
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
// Stream configuration
ck_tile::stream_config stream{nullptr, true, log_level, warmup, repeat};
@@ -160,20 +159,19 @@ int benchmark_pooling_single(int argc, char* argv[])
d_out.FromDevice(h_out.data());
d_out_index.FromDevice(h_out_index.data());
auto kernel_args =
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args = ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::reference_pool2d<InDataType,
ComputeDataType,
@@ -191,10 +189,9 @@ int benchmark_pooling_single(int argc, char* argv[])
if(SelectedKernel::kOutputIndex)
{
bool pass_index = ck_tile::check_err(
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL")
<< std::endl;
bool pass_index =
ck_tile::check_err(h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
std::cout << " Index verification: " << (pass_index ? "PASS" : "FAIL") << std::endl;
}
}