mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm * Modify test for adding shortcut for RMSNorm * Add fused parameter into tests * 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp * 1. Supports various stride and percisions. * Add support of Epilogue * Add fuse and epilogue support to rmsnorm ref * Modify rmsnorm example * Refactor tests/examples * Bug fix for newly added tests/examples * Bug fix for new tests 2 * Modify smoke test scripts remove dbg code * Supports non-smooth dyanmic quant * Update Rmsnorm2dFwd::GetName() * rename xscale and prec_sx to smoothscale and prec_sm Bug fix after rename Remove files * change example_rmsnorm2d_fwd.cpp * update performance calculator * Fix issue in two-pass when fuse add is enabled * Remove comment of beta --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,19 +24,19 @@ struct DynamicQuantEpilogueTraits
|
||||
|
||||
// this epilogue just store out a M*N matrix, row major
|
||||
template <typename AccDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct DynamicQuantEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
// TODO: we should put descriptor creation function into policy
|
||||
@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
|
||||
@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue
|
||||
#if 0
|
||||
// don't remove this
|
||||
// Note that if we set encoding purposely like this, you will result in compile fail
|
||||
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
|
||||
// TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
@@ -105,34 +105,18 @@ struct DynamicQuantEpilogue
|
||||
return reduce_crosswarp_sync.GetSmemSize();
|
||||
}
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp,
|
||||
typename XScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const XScaleWindow& x_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
auto reduce = GetBlockReduce2d();
|
||||
auto reduce_sync = GetBlockReduce2dSync();
|
||||
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
|
||||
const auto x_scale_window =
|
||||
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
auto x_scale = load_tile(x_scale_window);
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
|
||||
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
|
||||
});
|
||||
|
||||
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
|
||||
|
||||
auto row_absmax = [&]() {
|
||||
@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
|
||||
// Smooth Dynamic Quant
|
||||
template <typename ODramWindowTmp,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
|
||||
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
|
||||
});
|
||||
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
|
||||
}
|
||||
|
||||
// Dynamic Quant
|
||||
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user