mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[Ck tile] support rmsnorm and related fusion (#1605)
* Add reduce2d new api
* Prevent user use cross warp reduction
* Fix bug of std caculation
* Add rmsnorm2d
* Add rmsnorm small example
* Remove static assert to prevent compile fail
* Add script to test performance and correctness
* Add missing cmake change
* refine naming
* refine example of rmsnorm
* Fix bug of rmsnorm
* Refine naming
* Fix cmake
* clang format
* Refine pipeline name
* Add add_rmsnorm2d_rdquant kernel
* Add reduce op
* host verification
* Fix bug of one pass pipeline
* Refine tile size
* Add two pass pipeline
* Rename two pass to three pass
* Fix bug of kSaveX == false
* Add instance library
* Add test script
* Fix bug of x verification
* Add save_x to trait
* Add README
* Move reduce2d into reduce folder
* Fix bug of welford when number of m warp > 1
* remove reduncant comment
* 1. move 06_rmsnorm2d to 10_rmsnorm2d
2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant
* clang format and add missing header
* Add host validation of add + layernorm2d + rsquant
* Revert "Add host validation of add + layernorm2d + rsquant"
This reverts commit 936cb45797.
* Remove deprecated flag
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,101 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
|
||||
struct Rmsnorm2dFwdPipelineOnePass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_op"; // block per row
|
||||
else
|
||||
return "wpr_op"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d(
|
||||
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
|
||||
},
|
||||
square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
// rmsnorm computation
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
store_tile(y_window, y);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,36 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename GammaDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename InvRmsDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
struct Rmsnorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,131 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
|
||||
struct Rmsnorm2dFwdPipelineTwoPass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_tp"; // block per row
|
||||
else
|
||||
return "wpr_tp"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_reduce2d(x, square_sum, reduce_square_sum_func);
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
|
||||
},
|
||||
square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
});
|
||||
|
||||
store_tile(y_window, y);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user