diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index 1d843b5594..faa134e5c4 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -80,6 +80,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; + static constexpr auto WarpSize = ck_tile::get_warp_size(); static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index ada4c6f2da..c43d9c9a2e 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // Rmsnorm2d { ck_tile::HostTensor invRms_host_ref({m}); - + ck_tile::HostTensor unquant_y_host_ref({m, n}); // CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for // simplicity ck_tile::reference_rmsnorm2d_fwd( - x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); } // yscale