fix formating

This commit is contained in:
Aleksander Dudek
2026-02-10 18:45:06 +00:00
parent 2c2125f73e
commit 9bfcce5566
7 changed files with 245 additions and 234 deletions

View File

@@ -28,20 +28,36 @@ def extract_test_params(config_file, output_file, pooling_dim="2d"):
# Default 2D test parameters
test_params = [
{
"N": 1, "H": 8, "W": 8, "C": 32,
"Y": 2, "X": 2,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0,
"N": 1,
"H": 8,
"W": 8,
"C": 32,
"Y": 2,
"X": 2,
"stride_h": 2,
"stride_w": 2,
"dilation_h": 1,
"dilation_w": 1,
"pad_h_left": 0,
"pad_h_right": 0,
"pad_w_left": 0,
"pad_w_right": 0,
},
{
"N": 2, "H": 16, "W": 16, "C": 32,
"Y": 3, "X": 3,
"stride_h": 2, "stride_w": 2,
"dilation_h": 1, "dilation_w": 1,
"pad_h_left": 1, "pad_h_right": 1,
"pad_w_left": 1, "pad_w_right": 1,
"N": 2,
"H": 16,
"W": 16,
"C": 32,
"Y": 3,
"X": 3,
"stride_h": 2,
"stride_w": 2,
"dilation_h": 1,
"dilation_w": 1,
"pad_h_left": 1,
"pad_h_right": 1,
"pad_w_left": 1,
"pad_w_right": 1,
},
]
else: # 3d
@@ -51,13 +67,26 @@ def extract_test_params(config_file, output_file, pooling_dim="2d"):
# Default 3D test parameters
test_params = [
{
"N": 1, "D": 4, "H": 4, "W": 4, "C": 32,
"Z": 2, "Y": 2, "X": 2,
"stride_d": 2, "stride_h": 2, "stride_w": 2,
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
"pad_d_left": 0, "pad_d_right": 0,
"pad_h_left": 0, "pad_h_right": 0,
"pad_w_left": 0, "pad_w_right": 0,
"N": 1,
"D": 4,
"H": 4,
"W": 4,
"C": 32,
"Z": 2,
"Y": 2,
"X": 2,
"stride_d": 2,
"stride_h": 2,
"stride_w": 2,
"dilation_d": 1,
"dilation_h": 1,
"dilation_w": 1,
"pad_d_left": 0,
"pad_d_right": 0,
"pad_h_left": 0,
"pad_h_right": 0,
"pad_w_left": 0,
"pad_w_right": 0,
},
]

View File

@@ -33,24 +33,24 @@
/// @brief Test parameters for 2D pooling
struct PoolTestParams2D
{
int N, H, W, C; // Input dimensions (NHWC)
int Y, X; // Window size
int stride_h, stride_w; // Strides
int dilation_h, dilation_w; // Dilations
int pad_h_left, pad_h_right; // Height padding
int pad_w_left, pad_w_right; // Width padding
int N, H, W, C; // Input dimensions (NHWC)
int Y, X; // Window size
int stride_h, stride_w; // Strides
int dilation_h, dilation_w; // Dilations
int pad_h_left, pad_h_right; // Height padding
int pad_w_left, pad_w_right; // Width padding
};
/// @brief Test parameters for 3D pooling
struct PoolTestParams3D
{
int N, D, H, W, C; // Input dimensions (NDHWC)
int Z, Y, X; // Window size
int stride_d, stride_h, stride_w; // Strides
int dilation_d, dilation_h, dilation_w; // Dilations
int pad_d_left, pad_d_right; // Depth padding
int pad_h_left, pad_h_right; // Height padding
int pad_w_left, pad_w_right; // Width padding
int Z, Y, X; // Window size
int stride_d, stride_h, stride_w; // Strides
int dilation_d, dilation_h, dilation_w; // Dilations
int pad_d_left, pad_d_right; // Depth padding
int pad_h_left, pad_h_right; // Height padding
int pad_w_left, pad_w_right; // Width padding
};
// Include config-specific test parameters (after parameter structs are defined)
@@ -67,17 +67,17 @@ class PoolingTileEngineTest2D : public ::testing::TestWithParam<PoolTestParams2D
protected:
void SetUp() override
{
auto params = GetParam();
N_ = params.N;
H_ = params.H;
W_ = params.W;
C_ = params.C;
Y_ = params.Y;
X_ = params.X;
stride_h_ = params.stride_h;
stride_w_ = params.stride_w;
dilation_h_ = params.dilation_h;
dilation_w_ = params.dilation_w;
auto params = GetParam();
N_ = params.N;
H_ = params.H;
W_ = params.W;
C_ = params.C;
Y_ = params.Y;
X_ = params.X;
stride_h_ = params.stride_h;
stride_w_ = params.stride_w;
dilation_h_ = params.dilation_h;
dilation_w_ = params.dilation_w;
pad_h_left_ = params.pad_h_left;
pad_h_right_ = params.pad_h_right;
pad_w_left_ = params.pad_w_left;
@@ -86,8 +86,8 @@ class PoolingTileEngineTest2D : public ::testing::TestWithParam<PoolTestParams2D
// Calculate output dimensions
ck_tile::index_t Ys = (Y_ - 1) * dilation_h_ + 1;
ck_tile::index_t Xs = (X_ - 1) * dilation_w_ + 1;
Ho_ = (H_ + pad_h_left_ + pad_h_right_ - Ys) / stride_h_ + 1;
Wo_ = (W_ + pad_w_left_ + pad_w_right_ - Xs) / stride_w_ + 1;
Ho_ = (H_ + pad_h_left_ + pad_h_right_ - Ys) / stride_h_ + 1;
Wo_ = (W_ + pad_w_left_ + pad_w_right_ - Xs) / stride_w_ + 1;
}
int N_, H_, W_, C_;
@@ -123,10 +123,10 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
d_out_index.SetZero();
// Build shapes and strides (NHWC layout)
const auto input_shape = ck_tile::make_tuple(N_, H_, W_, C_);
const auto output_shape = ck_tile::make_tuple(N_, Ho_, Wo_, C_);
const auto input_strides = ck_tile::make_tuple(H_ * W_ * C_, W_ * C_, C_, 1);
const auto output_strides = ck_tile::make_tuple(Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
const auto input_shape = ck_tile::make_tuple(N_, H_, W_, C_);
const auto output_shape = ck_tile::make_tuple(N_, Ho_, Wo_, C_);
const auto input_strides = ck_tile::make_tuple(H_ * W_ * C_, W_ * C_, C_, 1);
const auto output_strides = ck_tile::make_tuple(Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
const auto window_lengths = ck_tile::make_tuple(Y_, X_);
const auto window_strides = ck_tile::make_tuple(stride_h_, stride_w_);
const auto window_dilations = ck_tile::make_tuple(dilation_h_, dilation_w_);
@@ -134,20 +134,19 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
const auto input_right_pads = ck_tile::make_tuple(pad_h_right_, pad_w_right_);
// Build host args for the generated kernel
auto host_args =
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
d_in.GetDeviceBuffer(),
d_out.GetDeviceBuffer(),
d_out_index.GetDeviceBuffer(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
// Stream config: no timing overhead for fastest execution
ck_tile::stream_config stream_config{nullptr, false, 0, 0, 1, false, false, 1};
@@ -175,20 +174,19 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
d_out_index.FromDevice(h_out_index.data());
// Compute reference on host
auto kernel_args_ref =
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args_ref = ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
h_in.data(),
h_out_ref.data(),
h_out_ref_index.data(),
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
ck_tile::reference_pool2d<InDataType,
ComputeDataType,
@@ -201,15 +199,14 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
h_in, h_out_ref, h_out_ref_index, kernel_args_ref, ReduceOpType{});
// Verify value results
bool pass_value =
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
bool pass_value = ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
EXPECT_TRUE(pass_value) << "Pooling value verification failed for " << KERNEL_NAME;
// Verify index results if output_index is enabled
if constexpr(SelectedKernel::kOutputIndex)
{
bool pass_index = ck_tile::check_err(
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
bool pass_index =
ck_tile::check_err(h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
EXPECT_TRUE(pass_index) << "Pooling index verification failed for " << KERNEL_NAME;
}
}
@@ -220,21 +217,19 @@ TEST_P(PoolingTileEngineTest2D, KernelInfo)
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
std::cout << "Problem size: N=" << N_ << " H=" << H_ << " W=" << W_ << " C=" << C_
<< " Window=" << Y_ << "x" << X_ << " Output=" << Ho_ << "x" << Wo_
<< std::endl;
<< " Window=" << Y_ << "x" << X_ << " Output=" << Ho_ << "x" << Wo_ << std::endl;
}
// Instantiate test suite with config-specific test parameters
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
INSTANTIATE_TEST_SUITE_P(
PoolingVerification,
PoolingTileEngineTest2D,
::testing::ValuesIn(CONFIG_TEST_PARAMS),
[](const ::testing::TestParamInfo<PoolTestParams2D>& param_info) {
return "N" + std::to_string(param_info.param.N) + "_H" +
std::to_string(param_info.param.H) + "_W" +
std::to_string(param_info.param.W) + "_C" +
std::to_string(param_info.param.C) + "_Y" +
std::to_string(param_info.param.Y) + "_X" +
std::to_string(param_info.param.X);
});
INSTANTIATE_TEST_SUITE_P(PoolingVerification,
PoolingTileEngineTest2D,
::testing::ValuesIn(CONFIG_TEST_PARAMS),
[](const ::testing::TestParamInfo<PoolTestParams2D>& param_info) {
return "N" + std::to_string(param_info.param.N) + "_H" +
std::to_string(param_info.param.H) + "_W" +
std::to_string(param_info.param.W) + "_C" +
std::to_string(param_info.param.C) + "_Y" +
std::to_string(param_info.param.Y) + "_X" +
std::to_string(param_info.param.X);
});