mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
@@ -127,9 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
@@ -212,7 +213,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
x_host.mData.begin(),
|
||||
std::plus<XDataType>{});
|
||||
[](auto x_, auto r_) {
|
||||
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
|
||||
ck_tile::type_convert<ComputeDataType>(r_);
|
||||
return ck_tile::type_convert<XDataType>(o_);
|
||||
});
|
||||
}
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
@@ -280,10 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(sy_host_dev.data());
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<InDataType>();
|
||||
@@ -294,8 +299,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(
|
||||
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev,
|
||||
x_host,
|
||||
std::string("ADD Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -314,12 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> sy_host_dev_row(
|
||||
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YResidualDataType> sy_host_ref_row(
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * stride,
|
||||
y_residual_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
|
||||
pass &= ck_tile::check_err(sy_host_dev_row,
|
||||
sy_host_ref_row,
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("ADD[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
|
||||
@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
store_tile(y_residual_window, x);
|
||||
|
||||
@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
|
||||
Reference in New Issue
Block a user