From 0f3ee13842551349a177ec0f97b70d2ff18d57af Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 18 Jun 2025 12:37:59 +0800 Subject: [PATCH] [CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd (#2243) * [CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd * fix error with the latest develop code. [ROCm/composable_kernel commit: 7aeec9a901e7e502e8d6ff8538b74cf0944ce318] --- .../11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp | 1 + .../example_add_rmsnorm2d_rdquant_fwd.cpp | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index 1d843b5594..faa134e5c4 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -80,6 +80,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using InputDataType = ck_tile::remove_cvref_t; using QuantizedDataType = ck_tile::remove_cvref_t; + static constexpr auto WarpSize = ck_tile::get_warp_size(); static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0); static constexpr ck_tile::index_t total_warps = diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index ada4c6f2da..c43d9c9a2e 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // Rmsnorm2d { ck_tile::HostTensor invRms_host_ref({m}); - + ck_tile::HostTensor unquant_y_host_ref({m, n}); // CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for // simplicity ck_tile::reference_rmsnorm2d_fwd( - x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); } // yscale