From b683de6b32190a2e645515a33e16ed68fb154da7 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 28 Oct 2024 19:49:08 +0000 Subject: [PATCH] Fix bug of x verification --- .../add_rmsnorm2d_rdquant_fwd.cpp | 23 +++++++++++++++++-- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 23 +++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index e31a2a221b..612ef62f89 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -147,8 +147,27 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.FromDevice(x_host_dev.data()); auto [rtol, atol] = get_elimit(); - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } } ck_tile::HostTensor y_host({m, n}); diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index 6a63d7aaf4..40fabf7f55 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -159,8 +159,27 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.FromDevice(x_host_dev.data()); auto [rtol, atol] = get_elimit(); - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } } ck_tile::HostTensor y_host({m, n});