mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Fix 11_add_rmsnorm2d_rdquant (#2207)
This commit is contained in:
@@ -67,13 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
@@ -184,6 +185,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
|
||||
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
@@ -191,8 +193,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
Reference in New Issue
Block a user