[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)

* Add shortcut to RMSNorm

* Modify test for adding shortcut for RMSNorm

* Add fused parameter into tests

* 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp

* 1. Supports various stride and percisions.

* Add support of Epilogue

* Add fuse and epilogue support to rmsnorm ref

* Modify rmsnorm example

* Refactor tests/examples

* Bug fix for newly added tests/examples

* Bug fix for new tests 2

* Modify smoke test scripts

remove dbg code

* Supports non-smooth dyanmic quant

* Update Rmsnorm2dFwd::GetName()

* rename xscale and prec_sx to smoothscale and prec_sm

Bug fix after rename

Remove files

* change example_rmsnorm2d_fwd.cpp

* update performance calculator

* Fix issue in two-pass when fuse add is enabled

* Remove comment of beta

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
ruanjm
2025-01-15 10:23:48 +08:00
committed by GitHub
parent c0b90f130f
commit 04dd314883
58 changed files with 1823 additions and 1045 deletions

View File

@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
assert(x_stride >= n);
using XDataType = DataType;
using XScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
using XDataType = DataType;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data());
smscale_buf.ToDevice(smscale_host.data());
constexpr bool kTwoPass = true;
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
XScaleDataType,
SmoothScaleDataType,
ComputeDataType,
YScaleDataType,
QYDataType,
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier
{
auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale;
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}