mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm * Modify test for adding shortcut for RMSNorm * Add fused parameter into tests * 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp * 1. Supports various stride and precisions. * Add support of Epilogue * Add fuse and epilogue support to rmsnorm ref * Modify rmsnorm example * Refactor tests/examples * Bug fix for newly added tests/examples * Bug fix for new tests 2 * Modify smoke test scripts remove dbg code * Supports non-smooth dynamic quant * Update Rmsnorm2dFwd::GetName() * rename xscale and prec_sx to smoothscale and prec_sm Bug fix after rename Remove files * change example_rmsnorm2d_fwd.cpp * update performance calculator * Fix issue in two-pass when fuse add is enabled * Remove comment of beta --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
template <typename XWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename QYWindow,
|
||||
typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
const SmoothScaleWindow& smscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t,
|
||||
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
auto smscale_window = make_tile_window(
|
||||
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
auto absmax = [&]() {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
|
||||
// Y = X * SmoothScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
|
||||
template <typename XDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename QYDataType_,
|
||||
@@ -18,12 +18,12 @@ template <typename XDataType_,
|
||||
bool kTwoPass_>
|
||||
struct SmoothquantPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using QYDataType = remove_cvref_t<QYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using QYDataType = remove_cvref_t<QYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
template <typename XWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename QYWindow,
|
||||
typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
const SmoothScaleWindow& smscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t row_size,
|
||||
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
auto smscale_window = make_tile_window(
|
||||
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
|
||||
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
constexpr auto x_size_per_row =
|
||||
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
|
||||
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(xscale_window, {Block_N});
|
||||
move_tile_window(smscale_window, {Block_N});
|
||||
}
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
@@ -114,20 +117,20 @@ struct SmoothquantPipelineTwoPass
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {-Block_N});
|
||||
move_tile_window(smscale_window, {-Block_N});
|
||||
move_tile_window(qy_window, {0, stride_to_right_most_window});
|
||||
|
||||
// recompute y and quantize y to qy
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
@@ -138,7 +141,7 @@ struct SmoothquantPipelineTwoPass
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {0, -Block_N});
|
||||
move_tile_window(smscale_window, {0, -Block_N});
|
||||
move_tile_window(qy_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user