mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK_TILE] Improve RMS/Layer Normalization 2 Pass Pipeline Performance (#1861)
* 50ms -> 28ms
* Fix bug in non fuse_add_store cases
* Fine tuned setting for 2 pass pipeline
* adjust workload
* remove unnecessary change
* add layernorm
* Adding output quant and unquant results at the same time.
* fix test
* fix format
* tune for cases 128x640 and 128x1024
* bug ifx
[ROCm/composable_kernel commit: d49abdaa87]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,11 +35,13 @@ template <typename XDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename InvRmsDataType,
|
||||
typename UnquantYDataType,
|
||||
typename Epilogue = reference_rmsnorm2d_default_epilogue>
|
||||
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
const HostTensor<GammaDataType>& gamma_n,
|
||||
HostTensor<YDataType>& y_m_n,
|
||||
HostTensor<InvRmsDataType>& invRms_m,
|
||||
HostTensor<UnquantYDataType>& unquant_y_m_n,
|
||||
ComputeDataType epsilon,
|
||||
Epilogue epilogue_functor = {})
|
||||
{
|
||||
@@ -69,7 +71,14 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
}
|
||||
|
||||
epilogue_functor(m, y_m_n, acc);
|
||||
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
|
||||
{
|
||||
epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
epilogue_functor(m, y_m_n, acc);
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "default_2d_epilogue.hpp"
|
||||
#include "dynamic_quant_epilogue.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// User can reuse DynamicQuantEpilogueTraits with this epilogue
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseSmoothInputScale_,
|
||||
bool UseRawStore_ = true,
|
||||
bool UseMax3_ = false>
|
||||
using Default2DAndDynamicQuantEpilogueTraits =
|
||||
DynamicQuantEpilogueTraits<kPadM_, kPadN_, UseSmoothInputScale_, UseRawStore_, UseMax3_>;
|
||||
|
||||
// This epilogue just store out a M*N matrix, row major
|
||||
template <typename AccDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename ODataType_,
|
||||
typename UnquantYDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct Default2DAndDynamicQuantEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct Default2DAndDynamicQuantEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
|
||||
|
||||
static constexpr bool kPadM = Problem::Traits::kPadM;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
|
||||
|
||||
using Default2DProblem =
|
||||
Default2DEpilogueProblem<AccDataType, UnquantYDataType, kPadM, kPadN, UseRawStore>;
|
||||
using Default2D = Default2DEpilogue<Default2DProblem>;
|
||||
using DynamicQuant = DynamicQuantEpilogue<Problem>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize());
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmpD,
|
||||
typename ODramWindowTmpQ,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
|
||||
ODramWindowTmpQ& o_quant_dram_window_tmp,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
|
||||
DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmpD,
|
||||
typename ODramWindowTmpQ,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
|
||||
ODramWindowTmpQ& o_quant_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
|
||||
DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -182,9 +182,16 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(beta_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
@@ -192,28 +199,43 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
// layernorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
auto acc = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(x_window));
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
{
|
||||
const auto x_bias = load_tile(x_bias_window);
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
|
||||
sweep_tile(acc, [&](auto idx) {
|
||||
// compute x = bias + x
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
@@ -235,9 +257,6 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT);
|
||||
Epilogue{}(y_window, ln);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(x_bias_window, {-Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(beta_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
|
||||
@@ -21,6 +21,7 @@ struct Rmsnorm2dFwdHostArgs
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
|
||||
void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -47,13 +48,15 @@ struct Rmsnorm2dFwd
|
||||
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -81,6 +84,7 @@ struct Rmsnorm2dFwd
|
||||
void* p_y_residual;
|
||||
void* p_y_scale;
|
||||
void* p_invRms;
|
||||
void* p_y_unquant;
|
||||
|
||||
float epsilon;
|
||||
|
||||
@@ -103,6 +107,7 @@ struct Rmsnorm2dFwd
|
||||
hargs.p_y_residual,
|
||||
hargs.p_y_scale,
|
||||
hargs.p_invRms,
|
||||
hargs.p_y_unquant,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
@@ -323,6 +328,30 @@ struct Rmsnorm2dFwd
|
||||
}
|
||||
}();
|
||||
|
||||
auto unquant_y_window = [&]() {
|
||||
if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
|
||||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) &&
|
||||
kSaveUnquant)
|
||||
{
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<UnquantYDataType*>(kargs.p_y_unquant),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
@@ -333,6 +362,7 @@ struct Rmsnorm2dFwd
|
||||
inv_rms_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
unquant_y_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem,
|
||||
|
||||
@@ -25,8 +25,9 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
@@ -54,6 +55,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
@@ -63,6 +65,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window_,
|
||||
UnquantYWindow& unquant_y_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
@@ -137,11 +140,26 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -12,6 +12,7 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename InvRmsDataType_,
|
||||
typename UnquantYDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
@@ -23,6 +24,7 @@ struct Rmsnorm2dFwdPipelineProblem
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
|
||||
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
@@ -54,6 +54,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
@@ -63,6 +64,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& /*sm_scale_window_*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
UnquantYWindow& /*unquant_y_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
@@ -136,32 +138,51 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
}
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
auto acc = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
|
||||
move_tile_window(y_residual_window, {0, -Block_N});
|
||||
}
|
||||
else
|
||||
{
|
||||
acc = cast_tile<ComputeDataType>(load_tile(x_window));
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(
|
||||
decltype(load_tile(x_window))::get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
@@ -176,8 +197,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
|
||||
Epilogue{}(y_window, rmsn);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
Rmsnorm2dFusedAddEnum kFusedAdd_,
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -46,6 +47,7 @@ struct Rmsnorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
Reference in New Issue
Block a user