diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index e31a2a221b..612ef62f89 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -147,8 +147,27 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.FromDevice(x_host_dev.data()); auto [rtol, atol] = get_elimit(); - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } } ck_tile::HostTensor y_host({m, n}); diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index 6a63d7aaf4..40fabf7f55 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -159,8 +159,27 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.FromDevice(x_host_dev.data()); auto [rtol, atol] = get_elimit(); - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } } ck_tile::HostTensor y_host({m, n});