Fix bug of x verification

This commit is contained in:
rocking
2024-10-28 19:49:08 +00:00
parent 88d3079065
commit b683de6b32
2 changed files with 42 additions and 4 deletions

View File

@@ -147,8 +147,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
if(stride == n)
{
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});

View File

@@ -159,8 +159,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.FromDevice(x_host_dev.data());
auto [rtol, atol] = get_elimit<XDataType>();
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
if(stride == n)
{
pass = ck_tile::check_err(
x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
x_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
x_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(x_host_dev_row,
x_host_ref_row,
std::string("x[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
ck_tile::HostTensor<YDataType> y_host({m, n});