mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd (#2243)
* [CK_TILE] fix build error in tile_add_rmsnorm2d_rdquant_fwd
* fix error with the latest develop code.
This commit is contained in:
@@ -80,6 +80,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
|
||||
@@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
|
||||
// CAUTION: kernel uses ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
- x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
+ x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
Reference in New Issue
Block a user