mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
store_tile(y_residual_window, x);
|
||||
|
||||
@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
|
||||
type_convert<YResidualDataType>(x(idx));
|
||||
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
|
||||
type_convert<ComputeDataType>(x(idx));
|
||||
x(idx) = type_convert<XDataType>(re_);
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
|
||||
Reference in New Issue
Block a user