Merge commit '7aeec9a901e7e502e8d6ff8538b74cf0944ce318' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-18 05:13:53 +00:00
parent fd6d24591e
commit 467012d53d
2 changed files with 3 additions and 2 deletions

View File

@@ -80,6 +80,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
static constexpr auto WarpSize = ck_tile::get_warp_size();
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =

View File

@@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Rmsnorm2d
{
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
// CAUTION: the kernel uses the ComputeDataType version of x, but we use XDataType here for
// simplicity
ck_tile::reference_rmsnorm2d_fwd<XDataType,
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
}
// yscale