[layernorm] hot fix (#1620)

* hot fix ln

* some rename
This commit is contained in:
carlushuang
2024-11-01 11:52:50 +08:00
committed by GitHub
parent c3a4800c5f
commit 550248deec
3 changed files with 29 additions and 17 deletions

View File

@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x);

View File

@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
}
// load gamma/beta (TODO: support no gamma/beta?)