mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm * Modify test for adding shortcut for RMSNorm * Add fused parameter into tests * 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp * 1. Supports various stride and percisions. * Add support of Epilogue * Add fuse and epilogue support to rmsnorm ref * Modify rmsnorm example * Refactor tests/examples * Bug fix for newly added tests/examples * Bug fix for new tests 2 * Modify smoke test scripts remove dbg code * Supports non-smooth dyanmic quant * Update Rmsnorm2dFwd::GetName() * rename xscale and prec_sx to smoothscale and prec_sm Bug fix after rename Remove files * change example_rmsnorm2d_fwd.cpp * update performance calculator * Fix issue in two-pass when fuse add is enabled * Remove comment of beta --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
@@ -43,16 +43,16 @@ struct Layernorm2dFwd
|
||||
using Epilogue = remove_cvref_t<Epilogue_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
@@ -84,7 +84,7 @@ struct Layernorm2dFwd
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
@@ -111,7 +111,7 @@ struct Layernorm2dFwd
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_x_scale,
|
||||
hargs.p_sm_scale,
|
||||
hargs.p_x_bias,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
@@ -171,7 +171,7 @@ struct Layernorm2dFwd
|
||||
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
|
||||
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
|
||||
@@ -356,18 +356,18 @@ struct Layernorm2dFwd
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto x_scale_window = [&]() {
|
||||
auto sm_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_x_scale),
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
|
||||
make_tuple(kargs.n),
|
||||
number<Vector_N>{});
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // x_scale no need pad
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
@@ -405,7 +405,7 @@ struct Layernorm2dFwd
|
||||
y_residual_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
x_scale_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& x_scale_window_,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
|
||||
}
|
||||
else
|
||||
Epilogue{}(y_window_, ln);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,23 +15,23 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& /*x_scale_window*/,
|
||||
const SmoothScaleWindow& /*sm_scale_window*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
|
||||
Reference in New Issue
Block a user