[CK_TILE] Enhance RMSNorm Accuracy: New Pipeline Pass for Selectable Implementation (#2409)

* Add Rmsnorm2dFwdPipelineModelSensitiveT5Pass

* Update rmsnorm2d_fwd_pipeline_model_sensitive_pass

1.  Add BlockReduce2dTreeCrossWarpSync

* Add Rmsnorm2dFusedModelSensitiveEnum

* Update patch

1. Revert generate.py
2. Remove comment in generate.py
3. Update tree cross warp reduce

* Refactor RMSNorm model enum and introduce T5-like option

* Update the number of stages for the cross-warp reduce

* Add a new command-line option to RMSNorm for testing the new pipeline

---------

Co-authored-by: Clement Lin <clement.lin@amd.com>
Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com>
MHYangAMD authored 2025-07-16 14:05:26 +08:00, committed by GitHub
parent 6b09f0823e
commit 3499fe67ff
13 changed files with 730 additions and 184 deletions

@@ -58,13 +58,14 @@ struct Rmsnorm2dFwd
    static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
    static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
-    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
-    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
-    static constexpr bool kPadM = false; // always no need to pad along M
-    static constexpr bool kPadN = Problem::Traits::kPadN;
-    static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
-    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
-    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
+    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
+    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
+    static constexpr bool kPadM = false; // always no need to pad along M
+    static constexpr bool kPadN = Problem::Traits::kPadN;
+    static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
+    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
+    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
+    static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
@@ -150,6 +151,8 @@ struct Rmsnorm2dFwd
            if (kPadN) n += "_pn";
            if (kSaveInvRms) n += "_rms";
            if (kTwoPass) n += "_2p";
+            if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm";
+            else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml";
            return n;
        }();
        auto prec_str = [&] () {

@@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
        return BlockReduce2dCrossWarpSync<P_>{};
    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
+    {
+        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
+                                        typename Problem::ComputeDataType,
+                                        typename Problem::BlockShape>;
+        return BlockReduce2dTreeCrossWarpSync<P_>{};
+    }
+
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {

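The hunk above only adds a policy hook that returns a BlockReduce2dTreeCrossWarpSync instance; the reducer itself is not shown in this diff. As a rough, conceptual illustration (not the library's implementation), a tree-style cross-warp reduction folds the per-warp partial sums staged in shared memory in log2(num_warps) pairwise steps rather than one linear sweep. A minimal host-side sketch, assuming a power-of-two warp count and purely illustrative names:

#include <cassert>
#include <cstdio>
#include <vector>

// Fold one partial sum per warp in log2(num_warps) pairwise steps, the way a
// tree-shaped cross-warp reduction combines per-warp partials staged in shared
// memory. On the GPU, each step would be followed by a block-wide barrier.
float tree_reduce_warp_partials(std::vector<float> partials)
{
    assert((partials.size() & (partials.size() - 1)) == 0 && "power-of-two warp count assumed");
    for(std::size_t stride = partials.size() / 2; stride > 0; stride /= 2)
    {
        for(std::size_t i = 0; i < stride; ++i)
        {
            partials[i] += partials[i + stride]; // the slot for warp i accumulates warp i+stride
        }
        // a __syncthreads()-style barrier would go here in device code
    }
    return partials[0];
}

int main()
{
    std::vector<float> per_warp{1.0f, 2.0f, 3.0f, 4.0f}; // e.g. four warps per block
    std::printf("block-wide sum = %f\n", tree_reduce_warp_partials(per_warp));
    return 0;
}
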
@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

/**
 * @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant of
 * Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass that uses a T5 model-like method.
 *
 * The T5 model, developed by Google, is a transformer-based architecture designed to perform
 * a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS
 * normalization is handled, particularly where intermediate values are cast to BF16. This aims
 * to achieve a value distribution similar to that produced by the vLLM HIP implementation,
 * thereby enhancing model accuracy.
 *
 * Note: While this implementation improves precision and can reduce discrepancies with vLLM,
 * it is not guaranteed to eliminate all differences or ensure uniform outcomes across every
 * use case.
 *
 * This variant of the original one-pass and two-pass approaches supports both fused and
 * non-fused add operations.
 */
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy = ck_tile::remove_cvref_t<Policy_>;

    using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;

    using XResidualDataType = XDataType;
    using YResidualDataType = XDataType;

    static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
    static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
    static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;

    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
    static constexpr bool kPadN = Problem::Traits::kPadN;
    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_op"; // block per row
        else
            return "wpr_op"; // warp per row
    }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename XWindow,
              typename XResidualWindow,
              typename GammaWindow,
              typename YWindow,
              typename YResidualWindow,
              typename InvRmsWindow,
              typename SmoothScaleWindow,
              typename YScaleWindow,
              typename UnquantYWindow,
              typename Epilogue>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const XResidualWindow& x_residual_window_,
                                   const GammaWindow& gamma_window_,
                                   YWindow& y_window_,
                                   const YResidualWindow& y_residual_window_,
                                   InvRmsWindow& inv_rms_window,
                                   const SmoothScaleWindow& sm_scale_window_,
                                   YScaleWindow& y_scale_window_,
                                   UnquantYWindow& unquant_y_window,
                                   ComputeDataType epsilon,
                                   ck_tile::index_t row_size,
                                   void* smem,
                                   Epilogue) const
    {
        const auto x_window =
            make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        const auto gamma_window = make_tile_window(
            gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
        const auto x_residual_window = make_tile_window(
            x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        auto y_residual_window = make_tile_window(
            y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());

        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
        auto reduce_sum_func = ReduceOp::Add{};
        auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_tree_cross_warp_sync =
            Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();

        auto x = load_tile(x_window);
        auto x_resi = load_tile(x_residual_window);

        // load gamma (TODO: support no gamma?)
        const auto gamma = load_tile(gamma_window);

        auto acc = cast_tile<ComputeDataType>(x);

        if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
                     kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
        {
            [[maybe_unused]] auto pre_out =
                make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());

            sweep_tile(x_resi, [&](auto idx) {
                // compute x = x_resi + x
                acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);

                // To make norm input align with residual output
                if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
                {
                    if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
                    {
                        pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
                    }
                    else
                    {
                        pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
                    }
                    acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
                }
            });

            if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
            {
                store_tile(y_residual_window, pre_out);
            }
        }

        // compute mean square each-thread->cross-lane->cross-warp
        auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
        set_tile(square_sum, 0);

        if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
        {
            sweep_tile(
                acc,
                [&](auto idx_0, auto idx_1) {
                    square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
                },
                sequence<1, 2>{});
        }
        else
        {
            square_sum = block_reduce2d(acc,
                                        reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
                                        reduce_square_sum_func);
        }

        block_reduce2d_sync(square_sum, reduce_sum_func);
        block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);

        // compute inv-rms
        auto inv_rms = tile_elementwise_in(
            [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);

        if constexpr(kSaveInvRms)
            store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));

        // rmsnorm computation
        auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());

        sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);

            const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);

            if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
            {
                const auto tmp0 =
                    float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
                const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
                    type_convert<ComputeDataType>(tmp0) * gamma_);
                const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
                rmsn(idx) = rmsn_;
            }
            else
            {
                const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
                const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
                rmsn(idx) = rmsn_;
            }
        });

        if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
        {
            if constexpr(kSaveUnquant)
            {
                Epilogue{}(
                    unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
            }
            else
            {
                Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
            }
        }
        else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
        {
            if constexpr(kSaveUnquant)
            {
                Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
            }
            else
            {
                Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
            }
        }
        else
        {
            Epilogue{}(y_window_, rmsn);
        }
    }
};
} // namespace ck_tile

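The distinguishing detail of this T5-like pass is in the rmsnorm computation above: the normalized value acc[idx] * inv_rms_[i_idx] is rounded to BF16 before gamma is applied, and the product is rounded again, instead of keeping the whole expression in FP32 and converting once at the end. The standalone sketch below emulates that rounding to show why the two orderings can give slightly different values; round_to_bf16 and the sample numbers are illustrative only and are not ck_tile's float_to_bf16.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate bfloat16 rounding (round-to-nearest-even) of a float, returned as float.
// Illustration only; this is not the library's bf16 type or conversion routine.
float round_to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    const std::uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb; // round to nearest, ties to even
    bits &= 0xFFFF0000u;   // keep only the 8-bit bf16 mantissa
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

int main()
{
    // One normalized element: x * inv_rms, then scaled by gamma.
    const float x = 0.73f, inv_rms = 1.4142135f, gamma = 1.02f;

    // Plain pipeline: evaluate in fp32 and convert once at the end.
    const float y_fp32_path = round_to_bf16(x * inv_rms * gamma);

    // T5-like pipeline: round the normalized value to bf16 first, then apply gamma
    // and round again, mimicking frameworks that materialize the activation in bf16.
    const float y_bf16_path = round_to_bf16(round_to_bf16(x * inv_rms) * gamma);

    std::printf("fp32 intermediate : %.8f\n", y_fp32_path);
    std::printf("bf16 intermediate : %.8f\n", y_bf16_path);
    return 0;
}
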
@@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass
        // compute inv-rms
        auto inv_rms = tile_elementwise_in(
-            [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
-            },
-            square_sum);
+            [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);

        if constexpr(kSaveInvRms)
            store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));

@@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_Q
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on

+enum class Rmsnorm2dSensitiveEnum
+{
+    NO_SPECIFIC_MODEL = 0,
+    // T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
+    // architecture designed for a variety of NLP tasks. This option mimics T5's approach to
+    // RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
+    T5_MODEL_LIKE = 1,
+};
+
+// clang-format off
+template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
+template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
+template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
+// clang-format on
+
template <bool kPadN_,
          bool kSaveInvRms_,
          bool kSaveUnquant_,
          bool kTwoPass_,
          Rmsnorm2dFusedAddEnum kFusedAdd_,
-          Rmsnorm2dFusedQuantEnum kFusedQuant_>
+          Rmsnorm2dFusedQuantEnum kFusedQuant_,
+          Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
struct Rmsnorm2dFwdTraits
{
-    static constexpr bool kPadN = kPadN_;
-    static constexpr bool kSaveInvRms = kSaveInvRms_;
-    static constexpr bool kSaveUnquant = kSaveUnquant_;
-    static constexpr bool kTwoPass = kTwoPass_;
-    static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
-    static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
+    static constexpr bool kPadN = kPadN_;
+    static constexpr bool kSaveInvRms = kSaveInvRms_;
+    static constexpr bool kSaveUnquant = kSaveUnquant_;
+    static constexpr bool kTwoPass = kTwoPass_;
+    static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
+    static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
+    static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
};
} // namespace ck_tile
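With the extra template parameter, Rmsnorm2dFwdTraits instantiations now also pass a Rmsnorm2dSensitiveEnum value. A hypothetical instantiation following the parameter order shown above; the include path and the particular flag values are assumptions for illustration, not taken from this diff:

#include "ck_tile/ops/rmsnorm2d.hpp" // assumed umbrella header; use the project's actual include

// Hypothetical traits selecting the T5-like pipeline variant; the bools and fused-op
// enumerators are arbitrary example values.
using T5LikeTraits =
    ck_tile::Rmsnorm2dFwdTraits<true,  // kPadN
                                false, // kSaveInvRms
                                false, // kSaveUnquant
                                false, // kTwoPass
                                ck_tile::Rmsnorm2dFusedAddEnum::PRE_ADD,
                                ck_tile::Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT,
                                ck_tile::Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE>;

static_assert(T5LikeTraits::kUseModelSensitiveRMSNorm ==
                  ck_tile::Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE,
              "these traits should select the model-sensitive (T5-like) pipeline");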