[CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd (#2243)

* [CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd

* fix error with the latest develop code.
This commit is contained in:
linqunAMD
2025-06-18 12:37:59 +08:00
committed by GitHub
parent a4e1248dba
commit 7aeec9a901
2 changed files with 3 additions and 2 deletions

View File

@@ -80,6 +80,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
static constexpr auto WarpSize = ck_tile::get_warp_size();
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =

View File

@@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
// CAUTION: the kernel uses the ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
}
// yscale