[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)

* Add shortcut to RMSNorm

* Modify test for adding shortcut for RMSNorm

* Add fused parameter into tests

* 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* 1. Supports various stride and percisions.

* Add support of Epilogue

* Add fuse and epilogue support to rmsnorm ref

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

remove dbg code

* Supports non-smooth dyanmic quant

* Update Rmsnorm2dFwd::GetName()

* rename xscale and prec_sx to smoothscale and prec_sm

Bug fix after rename

Remove files

* change example_rmsnorm2d_fwd.cpp

* update performance calculator

* Fix issue in two-pass when fuse add is enabled

* Remove comment of beta

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
ruanjm
2025-01-15 10:23:48 +08:00
committed by GitHub
parent c0b90f130f
commit 04dd314883
58 changed files with 1823 additions and 1045 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
{
using S = typename Problem::BlockShape;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_,
const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ck_tile::index_t,
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
auto smscale_window = make_tile_window(
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
auto y = tile_elementwise_in(
const auto x = load_tile(x_window);
const auto smscale = load_tile(smscale_window);
auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
smscale);
// compute absmax, cross-lane->cross-warp
auto absmax = [&]() {

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,9 +7,9 @@
namespace ck_tile {
// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
// Y = X * SmoothScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template <typename XDataType_,
typename XScaleDataType_,
typename SmoothScaleDataType_,
typename ComputeDataType_,
typename YScaleDataType_,
typename QYDataType_,
@@ -18,12 +18,12 @@ template <typename XDataType_,
bool kTwoPass_>
struct SmoothquantPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using QYDataType = remove_cvref_t<QYDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using QYDataType = remove_cvref_t<QYDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
template <typename XWindow,
typename SmoothScaleWindow,
typename QYWindow,
typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_,
const SmoothScaleWindow& smscale_window_,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ck_tile::index_t row_size,
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
auto smscale_window = make_tile_window(
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in(
const auto x = load_tile(x_window);
const auto smscale = load_tile(smscale_window);
const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
smscale);
constexpr auto x_size_per_row =
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
block_reduce2d(y, absmax, reduce_absmax_func);
move_tile_window(x_window, {0, Block_N});
move_tile_window(xscale_window, {Block_N});
move_tile_window(smscale_window, {Block_N});
}
// compute absmax, cross-lane->cross-warp
@@ -114,20 +117,20 @@ struct SmoothquantPipelineTwoPass
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {-Block_N});
move_tile_window(smscale_window, {-Block_N});
move_tile_window(qy_window, {0, stride_to_right_most_window});
// recompute y and quantize y to qy
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in(
const auto x = load_tile(x_window);
const auto smscale = load_tile(smscale_window);
const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
smscale);
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(qy, [&](auto idx) {
@@ -138,7 +141,7 @@ struct SmoothquantPipelineTwoPass
store_tile(qy_window, qy);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {0, -Block_N});
move_tile_window(smscale_window, {0, -Block_N});
move_tile_window(qy_window, {0, -Block_N});
}
}