mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm
* Modify test for adding shortcut for RMSNorm
* Add fused parameter into tests
* 1. Add YDataType. 2. Move rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp
* 1. Supports various strides and precisions.
* Add support of Epilogue
* Add fuse and epilogue support to rmsnorm ref
* Modify rmsnorm example
* Refactor tests/examples
* Bug fix for newly added tests/examples
* Bug fix for new tests 2
* Modify smoke test scripts
remove dbg code
* Supports non-smooth dynamic quant
* Update Rmsnorm2dFwd::GetName()
* rename xscale and prec_sx to smoothscale and prec_sm
Bug fix after rename
Remove files
* change example_rmsnorm2d_fwd.cpp
* update performance calculator
* Fix issue in two-pass when fuse add is enabled
* Remove comment of beta
---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 04dd314883]
This commit is contained in:
@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= n);
|
||||
assert(x_stride >= n);
|
||||
|
||||
using XDataType = DataType;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = DataType;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
constexpr bool kTwoPass = true;
|
||||
|
||||
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
|
||||
XScaleDataType,
|
||||
SmoothScaleDataType,
|
||||
ComputeDataType,
|
||||
YScaleDataType,
|
||||
QYDataType,
|
||||
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Kernel = ck_tile::Smoothquant<Pipeline>;
|
||||
|
||||
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
smscale_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
|
||||
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
|
||||
|
||||
for(int m_ = 0; m_ < m; ++m_)
|
||||
{
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
|
||||
y_host(m_, n_) = v_x * v_xscale;
|
||||
y_host(m_, n_) = v_x * v_smscale;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
|
||||
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "smoothquant.hpp"
|
||||
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
|
||||
|
||||
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
|
||||
typename SmoothquantTypeConfig<DataType>::XDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::XScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::ComputeDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::YScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::QYDataType,
|
||||
|
||||
@@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = SmoothquantTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using XScaleDataType = typename TypeConfig::XScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
smoothquant_traits traits{data_type};
|
||||
|
||||
smoothquant_args args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
smscale_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float ave_time = smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n +
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
|
||||
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
|
||||
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
|
||||
|
||||
for(int m_ = 0; m_ < m; ++m_)
|
||||
{
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
|
||||
y_host(m_, n_) = v_x * v_xscale;
|
||||
y_host(m_, n_) = v_x * v_smscale;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
|
||||
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
|
||||
template <>
|
||||
struct SmoothquantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
|
||||
Reference in New Issue
Block a user