mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm * Modify test for adding shortcut for RMSNorm * Add fused parameter into tests * 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp * 1. Supports various stride and percisions. * Add support of Epilogue * Add fuse and epilogue support to rmsnorm ref * Modify rmsnorm example * Refactor tests/examples * Bug fix for newly added tests/examples * Bug fix for new tests 2 * Modify smoke test scripts remove dbg code * Supports non-smooth dyanmic quant * Update Rmsnorm2dFwd::GetName() * rename xscale and prec_sx to smoothscale and prec_sm Bug fix after rename Remove files * change example_rmsnorm2d_fwd.cpp * update performance calculator * Fix issue in two-pass when fuse add is enabled * Remove comment of beta --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace ck_tile {
|
||||
// host side args
|
||||
struct SmoothquantHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_xscale; // [1, n], input, columnwise scale, fp32
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_smscale; // [1, n], input, columnwise scale, fp32
|
||||
|
||||
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale)
|
||||
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale
|
||||
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
|
||||
void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
@@ -30,11 +30,11 @@ struct Smoothquant
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -52,7 +52,7 @@ struct Smoothquant
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_xscale;
|
||||
const void* p_smscale;
|
||||
|
||||
void* p_yscale;
|
||||
void* p_qy;
|
||||
@@ -67,7 +67,7 @@ struct Smoothquant
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_xscale,
|
||||
hargs.p_smscale,
|
||||
hargs.p_yscale,
|
||||
hargs.p_qy,
|
||||
hargs.m,
|
||||
@@ -134,9 +134,9 @@ struct Smoothquant
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto xscale_window = [&]() {
|
||||
const auto smscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_xscale),
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
@@ -177,7 +177,7 @@ struct Smoothquant
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem);
|
||||
Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user