diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 43f4e8c724..8f029c212c 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -127,9 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor x_scale_host_dev({n}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); @@ -212,7 +213,11 @@ bool run(const ck_tile::ArgParser& arg_parser) x_host.mData.cend(), x_residual_host.mData.cbegin(), x_host.mData.begin(), - std::plus{}); + [](auto x_, auto r_) { + auto o_ = ck_tile::type_convert(x_) + + ck_tile::type_convert(r_); + return ck_tile::type_convert(o_); + }); } ck_tile::reference_layernorm2d_fwd sy_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor y_residual_host_dev({m, n}, {stride, 1}); if(fused_add == 1) { - y_residual_buf.FromDevice(sy_host_dev.data()); + y_residual_buf.FromDevice(y_residual_host_dev.data()); } auto [rtol, atol] = get_elimit(); @@ -294,8 +299,11 @@ bool run(const ck_tile::ArgParser& arg_parser) y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); if(fused_add == 1) { - pass &= ck_tile::check_err( - sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol); + pass &= ck_tile::check_err(y_residual_host_dev, + x_host, + std::string("ADD Error: Incorrect results!"), + rtol, + atol); } } else @@ -314,12 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser) atol); if(fused_add == 1) { - std::vector sy_host_dev_row( - sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n); - std::vector sy_host_ref_row( + std::vector y_residual_host_dev_row( + y_residual_host_dev.begin() + i_r * stride, + y_residual_host_dev.begin() + i_r * stride + n); + std::vector y_residual_host_ref_row( x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); - pass &= ck_tile::check_err(sy_host_dev_row, - sy_host_ref_row, + pass &= ck_tile::check_err(y_residual_host_dev_row, + y_residual_host_ref_row, std::string("ADD[") + std::to_string(i_r) + std::string("] Error: Incorrect results!"), rtol, diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 16a7c3b86d..5601f3a68c 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - x(idx) = type_convert(x_resi(idx)) + - type_convert(x(idx)); + auto re_ = type_convert(x_resi(idx)) + + type_convert(x(idx)); + x(idx) = type_convert(re_); }); if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) store_tile(y_residual_window, x); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index ec10efbc69..48f66739da 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - x(idx) = type_convert(x_resi(idx)) + - type_convert(x(idx)); + auto re_ = type_convert(x_resi(idx)) + + type_convert(x(idx)); + x(idx) = type_convert(re_); }); if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) { @@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - x(idx) = type_convert(x_resi(idx)) + - type_convert(x(idx)); + auto re_ = type_convert(x_resi(idx)) + + type_convert(x(idx)); + x(idx) = type_convert(re_); }); } // load gamma/beta (TODO: support no gamma/beta?)