This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,381 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
// host side args
struct Rmsnorm2dFwdHostArgs
{
const void* p_x; // [m ,n], input, fp16/bf16
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
const void* p_gamma; // [1, n], gamma, prec same as input
void* p_y; // [m, n], output, fp16/bf16
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
float epsilon;
index_t m;
index_t n;
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_, typename Epilogue_>
struct Rmsnorm2dFwd
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
// for simplicity, shortcut input/output type is same as X
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
const void* p_x;
const void* p_x_residual;
const void* p_sm_scale;
const void* p_gamma;
void* p_y;
void* p_y_residual;
void* p_y_scale;
void* p_invRms;
void* p_y_unquant;
float epsilon;
index_t m;
index_t n;
index_t x_stride; // x row_stride
index_t xr_stride; // x residule row stride
index_t y_stride; // y row stride
index_t yr_stride; // y residule row stride
};
using Hargs = Rmsnorm2dFwdHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_x_residual,
hargs.p_sm_scale,
hargs.p_gamma,
hargs.p_y,
hargs.p_y_residual,
hargs.p_y_scale,
hargs.p_invRms,
hargs.p_y_unquant,
hargs.epsilon,
hargs.m,
hargs.n,
hargs.x_stride,
hargs.xr_stride,
hargs.y_stride,
hargs.yr_stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return dim3(integer_divide_ceil(hargs.m, Block_M));
}
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
: Problem::BlockShape::template GetBlockSize<false>();
}
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
// in byte
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
if (kPadN) n += "_pn";
if (kSaveInvRms) n += "_rms";
if (kTwoPass) n += "_2p";
if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm";
else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml";
return n; }();
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<XDataType>::name);
if (!std::is_same_v<XDataType, YDataType>) {
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) {
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
}
return base_str;
}();
return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
// clang-format on
#undef _SS_
#undef _TS_
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.x_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto x_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XResidualDataType*>(kargs.p_x_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.xr_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
const auto gamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.n),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
auto y_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
auto y_residual_window = [&]() {
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YResidualDataType*>(kargs.p_y_residual),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.yr_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
auto inv_rms_window = [&]() {
if constexpr(kSaveInvRms)
{
const auto inv_rms_m = [&]() {
const auto inv_rms_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvRmsDataType*>(kargs.p_invRms),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<Block_M>{}));
}();
auto sm_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
make_tuple(kargs.n),
number<Vector_N>{});
return pad_tensor_view(tmp_0_,
make_tuple(number<Block_N>{}),
sequence<false>{}); // sm_scale no need pad
}();
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_N>{}));
}
}();
auto y_scale_window = [&]() {
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
const auto win_ = [&]() {
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_y_scale),
make_tuple(kargs.m),
number<1>{});
return pad_tensor_view(
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
}();
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}));
}
}();
auto unquant_y_window = [&]() {
if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) &&
kSaveUnquant)
{
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<UnquantYDataType*>(kargs.p_y_unquant),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(tmp_,
make_tuple(number<Block_M>{}, number<Block_N>{}),
sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}
else
{
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
}
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window,
x_residual_window,
gamma_window,
y_window,
y_residual_window,
inv_rms_window,
sm_scale_window,
y_scale_window,
unquant_y_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.n,
smem,
Epilogue{});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,95 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace ck_tile {
struct Rmsnorm2dFwdPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
if constexpr(Problem::kNeedCrossWarpSync)
{
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
using block_reduce2d = BlockReduce2d<P_>;
using x_block_tile =
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
MakeXBlockTileDistribution<Problem>()));
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
else
{
return 1; // zero size arrays are an extension
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,226 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
/**
* @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant
* based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like
* method.
*
* The T5 model, developed by Google, is a transformer-based architecture designed to perform
* a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS
* normalization is handled, particularly where intermediate values are cast to BF16. This aims to
* achieve a similar value distribution to that produced by the VLLM hip implementation, thereby
* enhancing model accuracy.
*
* Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is
* not guaranteed to eliminate all differences or ensure uniform outcomes across every use case.
*
* This implementation is a variant based on the original one-pass and two-pass approaches,
* allowing for both fused and non-fused add operations.
*/
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
else
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
UnquantYWindow& unquant_y_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
[[maybe_unused]] auto pre_out =
make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
// To make norm input align with residual output
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
{
pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
}
else
{
pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
}
acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
}
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, pre_out);
}
}
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
set_tile(square_sum, 0);
if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
{
sweep_tile(
acc,
[&](auto idx_0, auto idx_1) {
square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
},
sequence<1, 2>{});
}
else
{
square_sum = block_reduce2d(acc,
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp = acc[idx] * inv_rms_[i_idx];
const auto tmp_bf16 = float_to_bf16<bf16_rounding_mode::standard>(tmp);
const auto rmsn_ = type_convert<ComputeDataType>(tmp_bf16) * gamma_;
rmsn(idx) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
rmsn(idx) = rmsn_;
}
});
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
}
else
{
Epilogue{}(y_window_, rmsn, nullptr);
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,167 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineOnePass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_op"; // block per row
else
return "wpr_op"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window_,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window_,
UnquantYWindow& unquant_y_window,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
const auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
const auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
const auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
}
}
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d(acc,
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
reduce_square_sum_func);
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
rmsn(idx) = rmsn_;
});
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
}
}
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
{
if constexpr(kSaveUnquant)
{
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
}
else
{
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
}
}
else
{
Epilogue{}(y_window_, rmsn, nullptr);
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename GammaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename InvRmsDataType_,
typename UnquantYDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename BlockShape_,
typename Traits_>
struct Rmsnorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
struct Rmsnorm2dFwdPipelineTwoPass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
using XResidualDataType = XDataType;
using YResidualDataType = XDataType;
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp"; // block per row
else
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename XResidualWindow,
typename GammaWindow,
typename YWindow,
typename YResidualWindow,
typename InvRmsWindow,
typename SmoothScaleWindow,
typename YScaleWindow,
typename UnquantYWindow,
typename Epilogue>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XResidualWindow& x_residual_window_,
const GammaWindow& gamma_window_,
YWindow& y_window,
const YResidualWindow& y_residual_window_,
InvRmsWindow& inv_rms_window,
const SmoothScaleWindow& /*sm_scale_window_*/,
YScaleWindow& /*y_scale_window*/,
UnquantYWindow& /*unquant_y_window*/,
ComputeDataType epsilon,
ck_tile::index_t row_size,
void* smem,
Epilogue) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
auto x_residual_window = make_tile_window(
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto y_residual_window = make_tile_window(
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
// Problem::BlockShape
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N));
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
auto reduce_sum_func = ReduceOp::Add{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(x_residual_window, {0, Block_N});
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
move_tile_window(y_residual_window, {0, Block_N});
}
}
block_reduce2d(acc, square_sum, reduce_square_sum_func);
}
block_reduce2d_sync(square_sum, reduce_sum_func);
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
},
square_sum);
if constexpr(kSaveInvRms)
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
}
move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation
for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
{
auto acc = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
acc = cast_tile<ComputeDataType>(load_tile(x_window));
move_tile_window(x_window, {0, -Block_N});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
auto x_resi = load_tile(x_residual_window);
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
move_tile_window(x_residual_window, {0, -Block_N});
}
}
// load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window);
// rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
rmsn(idx) = rmsn_;
});
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn, nullptr);
move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,73 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
enum class Rmsnorm2dFusedAddEnum
{
NO_ADD = 0,
// fused add before RMSNorm and store result to global
PRE_ADD_STORE = 1,
// fused add before RMSNorm, but not store result
PRE_ADD = 2,
};
// clang-format off
template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
// clang-format on
enum class Rmsnorm2dFusedQuantEnum
{
NO_SWEEP = 0,
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
};
// clang-format off
template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
// clang-format on
enum class Rmsnorm2dSensitiveEnum
{
NO_SPECIFIC_MODEL = 0,
// T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
// architecture designed for a variety of NLP tasks. This option mimics T5's approach to
// RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
T5_MODEL_LIKE = 1,
};
// clang-format off
template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
// clang-format on
template <bool kPadN_,
bool kSaveInvRms_,
bool kSaveUnquant_,
bool kTwoPass_,
Rmsnorm2dFusedAddEnum kFusedAdd_,
Rmsnorm2dFusedQuantEnum kFusedQuant_,
Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
struct Rmsnorm2dFwdTraits
{
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveUnquant = kSaveUnquant_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
};
} // namespace ck_tile