[CK_TILE] Improve RMS/Layer Normalization 2 Pass Pipeline Performance (#1861)

* 50ms -> 28ms

* Fix bug in non fuse_add_store cases

* Fine tuned setting for 2 pass pipeline

* adjust workload

* remove unnecessary change

* add layernorm

* Add support for outputting quantized and unquantized results at the same time.

* fix test

* fix format

* tune for cases 128x640 and 128x1024

* bug fix
This commit is contained in:
ruanjm
2025-03-25 20:09:45 +08:00
committed by GitHub
parent d2eab23958
commit d49abdaa87
15 changed files with 492 additions and 135 deletions

View File

@@ -21,6 +21,7 @@ struct Rmsnorm2dFwdHostArgs
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
float epsilon;
@@ -47,13 +48,15 @@ struct Rmsnorm2dFwd
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -81,6 +84,7 @@ struct Rmsnorm2dFwd
void* p_y_residual;
void* p_y_scale;
void* p_invRms;
void* p_y_unquant;
float epsilon;
@@ -103,6 +107,7 @@ struct Rmsnorm2dFwd
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_invRms,
hargs.p_y_unquant,
hargs.epsilon,
hargs.m,
hargs.n,
@@ -323,6 +328,30 @@ struct Rmsnorm2dFwd
}
}();
// Build the tile window over the optional "unquant" output buffer
// (kargs.p_y_unquant). A real window is created only when a dynamic-quant mode
// (SMOOTH_DYNAMIC_QUANT or DYNAMIC_QUANT) is selected at compile time AND
// kSaveUnquant is set; in every other configuration a null tile window with
// the same tile lengths is returned so the Pipeline call keeps a uniform
// argument list.
auto unquant_y_window = [&]() {
if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) &&
kSaveUnquant)
{
// Global-memory view of an (m, n) tensor of UnquantYDataType with strides
// (y_stride, 1).
// NOTE(review): the row stride is taken from kargs.y_stride, i.e. the
// unquant buffer is assumed to share the row stride of y -- confirm with callers.
// The trailing number<Vector_N>{} / number<1>{} arguments are presumably
// vector-length / last-dimension-stride hints -- TODO confirm against
// make_naive_tensor_view.
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<UnquantYDataType*>(kargs.p_y_unquant),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
// Pad the view against the (Block_M, Block_N) tile lengths; kPadM / kPadN
// select per dimension whether padding is applied.
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
// Block_M x Block_N window whose origin is row iM, column 0.
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
// Unquant output not requested: return a placeholder window that is not
// backed by kargs.p_y_unquant.
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
@@ -333,6 +362,7 @@ struct Rmsnorm2dFwd
inv_rms_window,
sm_scale_window,
y_scale_window,
unquant_y_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem,

View File

@@ -25,8 +25,9 @@ struct Rmsnorm2dFwdPipelineOnePass
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
@@ -54,6 +55,7 @@ struct Rmsnorm2dFwdPipelineOnePass
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
@@ -63,6 +65,7 @@ struct Rmsnorm2dFwdPipelineOnePass
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
UnquantYWindow& unquant_y_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
@@ -137,11 +140,26 @@ struct Rmsnorm2dFwdPipelineOnePass
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
if constexpr(kSaveUnquant)
{
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
}
else
{

View File

@@ -12,6 +12,7 @@ template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename InvRmsDataType_,
typename UnquantYDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
@@ -23,6 +24,7 @@ struct Rmsnorm2dFwdPipelineProblem
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;

View File

@@ -54,6 +54,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
@@ -63,6 +64,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& /*sm_scale_window_*/,
YScaleWindow& /*y_scale_window*/,
UnquantYWindow& /*unquant_y_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
@@ -136,32 +138,51 @@ struct Rmsnorm2dFwdPipelineTwoPass
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
}
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
auto acc = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
acc = cast_tile<ComputeDataType>(load_tile(x_window));
move_tile_window(x_window, {0, -Block_N});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
auto x_resi = load_tile(x_residual_window);
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
move_tile_window(x_residual_window, {0, -Block_N});
}
}
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
auto rmsn = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -176,8 +197,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
}

View File

@@ -39,6 +39,7 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DY
template <bool kPadN_,
bool kSaveInvRms_,
bool kSaveUnquant_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_>
@@ -46,6 +47,7 @@ struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveUnquant = kSaveUnquant_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;