mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
fix formating
This commit is contained in:
@@ -28,20 +28,36 @@ def extract_test_params(config_file, output_file, pooling_dim="2d"):
|
||||
# Default 2D test parameters
|
||||
test_params = [
|
||||
{
|
||||
"N": 1, "H": 8, "W": 8, "C": 32,
|
||||
"Y": 2, "X": 2,
|
||||
"stride_h": 2, "stride_w": 2,
|
||||
"dilation_h": 1, "dilation_w": 1,
|
||||
"pad_h_left": 0, "pad_h_right": 0,
|
||||
"pad_w_left": 0, "pad_w_right": 0,
|
||||
"N": 1,
|
||||
"H": 8,
|
||||
"W": 8,
|
||||
"C": 32,
|
||||
"Y": 2,
|
||||
"X": 2,
|
||||
"stride_h": 2,
|
||||
"stride_w": 2,
|
||||
"dilation_h": 1,
|
||||
"dilation_w": 1,
|
||||
"pad_h_left": 0,
|
||||
"pad_h_right": 0,
|
||||
"pad_w_left": 0,
|
||||
"pad_w_right": 0,
|
||||
},
|
||||
{
|
||||
"N": 2, "H": 16, "W": 16, "C": 32,
|
||||
"Y": 3, "X": 3,
|
||||
"stride_h": 2, "stride_w": 2,
|
||||
"dilation_h": 1, "dilation_w": 1,
|
||||
"pad_h_left": 1, "pad_h_right": 1,
|
||||
"pad_w_left": 1, "pad_w_right": 1,
|
||||
"N": 2,
|
||||
"H": 16,
|
||||
"W": 16,
|
||||
"C": 32,
|
||||
"Y": 3,
|
||||
"X": 3,
|
||||
"stride_h": 2,
|
||||
"stride_w": 2,
|
||||
"dilation_h": 1,
|
||||
"dilation_w": 1,
|
||||
"pad_h_left": 1,
|
||||
"pad_h_right": 1,
|
||||
"pad_w_left": 1,
|
||||
"pad_w_right": 1,
|
||||
},
|
||||
]
|
||||
else: # 3d
|
||||
@@ -51,13 +67,26 @@ def extract_test_params(config_file, output_file, pooling_dim="2d"):
|
||||
# Default 3D test parameters
|
||||
test_params = [
|
||||
{
|
||||
"N": 1, "D": 4, "H": 4, "W": 4, "C": 32,
|
||||
"Z": 2, "Y": 2, "X": 2,
|
||||
"stride_d": 2, "stride_h": 2, "stride_w": 2,
|
||||
"dilation_d": 1, "dilation_h": 1, "dilation_w": 1,
|
||||
"pad_d_left": 0, "pad_d_right": 0,
|
||||
"pad_h_left": 0, "pad_h_right": 0,
|
||||
"pad_w_left": 0, "pad_w_right": 0,
|
||||
"N": 1,
|
||||
"D": 4,
|
||||
"H": 4,
|
||||
"W": 4,
|
||||
"C": 32,
|
||||
"Z": 2,
|
||||
"Y": 2,
|
||||
"X": 2,
|
||||
"stride_d": 2,
|
||||
"stride_h": 2,
|
||||
"stride_w": 2,
|
||||
"dilation_d": 1,
|
||||
"dilation_h": 1,
|
||||
"dilation_w": 1,
|
||||
"pad_d_left": 0,
|
||||
"pad_d_right": 0,
|
||||
"pad_h_left": 0,
|
||||
"pad_h_right": 0,
|
||||
"pad_w_left": 0,
|
||||
"pad_w_right": 0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -33,24 +33,24 @@
|
||||
/// @brief Test parameters for 2D pooling
|
||||
struct PoolTestParams2D
|
||||
{
|
||||
int N, H, W, C; // Input dimensions (NHWC)
|
||||
int Y, X; // Window size
|
||||
int stride_h, stride_w; // Strides
|
||||
int dilation_h, dilation_w; // Dilations
|
||||
int pad_h_left, pad_h_right; // Height padding
|
||||
int pad_w_left, pad_w_right; // Width padding
|
||||
int N, H, W, C; // Input dimensions (NHWC)
|
||||
int Y, X; // Window size
|
||||
int stride_h, stride_w; // Strides
|
||||
int dilation_h, dilation_w; // Dilations
|
||||
int pad_h_left, pad_h_right; // Height padding
|
||||
int pad_w_left, pad_w_right; // Width padding
|
||||
};
|
||||
|
||||
/// @brief Test parameters for 3D pooling
|
||||
struct PoolTestParams3D
|
||||
{
|
||||
int N, D, H, W, C; // Input dimensions (NDHWC)
|
||||
int Z, Y, X; // Window size
|
||||
int stride_d, stride_h, stride_w; // Strides
|
||||
int dilation_d, dilation_h, dilation_w; // Dilations
|
||||
int pad_d_left, pad_d_right; // Depth padding
|
||||
int pad_h_left, pad_h_right; // Height padding
|
||||
int pad_w_left, pad_w_right; // Width padding
|
||||
int Z, Y, X; // Window size
|
||||
int stride_d, stride_h, stride_w; // Strides
|
||||
int dilation_d, dilation_h, dilation_w; // Dilations
|
||||
int pad_d_left, pad_d_right; // Depth padding
|
||||
int pad_h_left, pad_h_right; // Height padding
|
||||
int pad_w_left, pad_w_right; // Width padding
|
||||
};
|
||||
|
||||
// Include config-specific test parameters (after parameter structs are defined)
|
||||
@@ -67,17 +67,17 @@ class PoolingTileEngineTest2D : public ::testing::TestWithParam<PoolTestParams2D
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
auto params = GetParam();
|
||||
N_ = params.N;
|
||||
H_ = params.H;
|
||||
W_ = params.W;
|
||||
C_ = params.C;
|
||||
Y_ = params.Y;
|
||||
X_ = params.X;
|
||||
stride_h_ = params.stride_h;
|
||||
stride_w_ = params.stride_w;
|
||||
dilation_h_ = params.dilation_h;
|
||||
dilation_w_ = params.dilation_w;
|
||||
auto params = GetParam();
|
||||
N_ = params.N;
|
||||
H_ = params.H;
|
||||
W_ = params.W;
|
||||
C_ = params.C;
|
||||
Y_ = params.Y;
|
||||
X_ = params.X;
|
||||
stride_h_ = params.stride_h;
|
||||
stride_w_ = params.stride_w;
|
||||
dilation_h_ = params.dilation_h;
|
||||
dilation_w_ = params.dilation_w;
|
||||
pad_h_left_ = params.pad_h_left;
|
||||
pad_h_right_ = params.pad_h_right;
|
||||
pad_w_left_ = params.pad_w_left;
|
||||
@@ -86,8 +86,8 @@ class PoolingTileEngineTest2D : public ::testing::TestWithParam<PoolTestParams2D
|
||||
// Calculate output dimensions
|
||||
ck_tile::index_t Ys = (Y_ - 1) * dilation_h_ + 1;
|
||||
ck_tile::index_t Xs = (X_ - 1) * dilation_w_ + 1;
|
||||
Ho_ = (H_ + pad_h_left_ + pad_h_right_ - Ys) / stride_h_ + 1;
|
||||
Wo_ = (W_ + pad_w_left_ + pad_w_right_ - Xs) / stride_w_ + 1;
|
||||
Ho_ = (H_ + pad_h_left_ + pad_h_right_ - Ys) / stride_h_ + 1;
|
||||
Wo_ = (W_ + pad_w_left_ + pad_w_right_ - Xs) / stride_w_ + 1;
|
||||
}
|
||||
|
||||
int N_, H_, W_, C_;
|
||||
@@ -123,10 +123,10 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
|
||||
d_out_index.SetZero();
|
||||
|
||||
// Build shapes and strides (NHWC layout)
|
||||
const auto input_shape = ck_tile::make_tuple(N_, H_, W_, C_);
|
||||
const auto output_shape = ck_tile::make_tuple(N_, Ho_, Wo_, C_);
|
||||
const auto input_strides = ck_tile::make_tuple(H_ * W_ * C_, W_ * C_, C_, 1);
|
||||
const auto output_strides = ck_tile::make_tuple(Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
|
||||
const auto input_shape = ck_tile::make_tuple(N_, H_, W_, C_);
|
||||
const auto output_shape = ck_tile::make_tuple(N_, Ho_, Wo_, C_);
|
||||
const auto input_strides = ck_tile::make_tuple(H_ * W_ * C_, W_ * C_, C_, 1);
|
||||
const auto output_strides = ck_tile::make_tuple(Ho_ * Wo_ * C_, Wo_ * C_, C_, 1);
|
||||
const auto window_lengths = ck_tile::make_tuple(Y_, X_);
|
||||
const auto window_strides = ck_tile::make_tuple(stride_h_, stride_w_);
|
||||
const auto window_dilations = ck_tile::make_tuple(dilation_h_, dilation_w_);
|
||||
@@ -134,20 +134,19 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
|
||||
const auto input_right_pads = ck_tile::make_tuple(pad_h_right_, pad_w_right_);
|
||||
|
||||
// Build host args for the generated kernel
|
||||
auto host_args =
|
||||
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
d_in.GetDeviceBuffer(),
|
||||
d_out.GetDeviceBuffer(),
|
||||
d_out_index.GetDeviceBuffer(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
d_in.GetDeviceBuffer(),
|
||||
d_out.GetDeviceBuffer(),
|
||||
d_out_index.GetDeviceBuffer(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// Stream config: no timing overhead for fastest execution
|
||||
ck_tile::stream_config stream_config{nullptr, false, 0, 0, 1, false, false, 1};
|
||||
@@ -175,20 +174,19 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
|
||||
d_out_index.FromDevice(h_out_index.data());
|
||||
|
||||
// Compute reference on host
|
||||
auto kernel_args_ref =
|
||||
ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
h_in.data(),
|
||||
h_out_ref.data(),
|
||||
h_out_ref_index.data(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
auto kernel_args_ref = ck_tile::PoolKernelArgs<decltype(input_shape), decltype(window_lengths)>{
|
||||
h_in.data(),
|
||||
h_out_ref.data(),
|
||||
h_out_ref_index.data(),
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
output_strides,
|
||||
window_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
ck_tile::reference_pool2d<InDataType,
|
||||
ComputeDataType,
|
||||
@@ -201,15 +199,14 @@ TEST_P(PoolingTileEngineTest2D, BasicFunctionality)
|
||||
h_in, h_out_ref, h_out_ref_index, kernel_args_ref, ReduceOpType{});
|
||||
|
||||
// Verify value results
|
||||
bool pass_value =
|
||||
ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
bool pass_value = ck_tile::check_err(h_out, h_out_ref, "Error: Incorrect values!", 1e-5, 1e-5);
|
||||
EXPECT_TRUE(pass_value) << "Pooling value verification failed for " << KERNEL_NAME;
|
||||
|
||||
// Verify index results if output_index is enabled
|
||||
if constexpr(SelectedKernel::kOutputIndex)
|
||||
{
|
||||
bool pass_index = ck_tile::check_err(
|
||||
h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
|
||||
bool pass_index =
|
||||
ck_tile::check_err(h_out_index, h_out_ref_index, "Error: Incorrect indices!", 0, 0);
|
||||
EXPECT_TRUE(pass_index) << "Pooling index verification failed for " << KERNEL_NAME;
|
||||
}
|
||||
}
|
||||
@@ -220,21 +217,19 @@ TEST_P(PoolingTileEngineTest2D, KernelInfo)
|
||||
|
||||
std::cout << "Testing kernel: " << KERNEL_NAME << std::endl;
|
||||
std::cout << "Problem size: N=" << N_ << " H=" << H_ << " W=" << W_ << " C=" << C_
|
||||
<< " Window=" << Y_ << "x" << X_ << " Output=" << Ho_ << "x" << Wo_
|
||||
<< std::endl;
|
||||
<< " Window=" << Y_ << "x" << X_ << " Output=" << Ho_ << "x" << Wo_ << std::endl;
|
||||
}
|
||||
|
||||
// Instantiate test suite with config-specific test parameters
|
||||
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
PoolingVerification,
|
||||
PoolingTileEngineTest2D,
|
||||
::testing::ValuesIn(CONFIG_TEST_PARAMS),
|
||||
[](const ::testing::TestParamInfo<PoolTestParams2D>& param_info) {
|
||||
return "N" + std::to_string(param_info.param.N) + "_H" +
|
||||
std::to_string(param_info.param.H) + "_W" +
|
||||
std::to_string(param_info.param.W) + "_C" +
|
||||
std::to_string(param_info.param.C) + "_Y" +
|
||||
std::to_string(param_info.param.Y) + "_X" +
|
||||
std::to_string(param_info.param.X);
|
||||
});
|
||||
INSTANTIATE_TEST_SUITE_P(PoolingVerification,
|
||||
PoolingTileEngineTest2D,
|
||||
::testing::ValuesIn(CONFIG_TEST_PARAMS),
|
||||
[](const ::testing::TestParamInfo<PoolTestParams2D>& param_info) {
|
||||
return "N" + std::to_string(param_info.param.N) + "_H" +
|
||||
std::to_string(param_info.param.H) + "_W" +
|
||||
std::to_string(param_info.param.W) + "_C" +
|
||||
std::to_string(param_info.param.C) + "_Y" +
|
||||
std::to_string(param_info.param.Y) + "_X" +
|
||||
std::to_string(param_info.param.X);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user