mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] Add Various Fusion Functions to RMSNorm (#1802)
* Add shortcut to RMSNorm
* Modify test for adding shortcut for RMSNorm
* Add fused parameter into tests
* 1. Add YDataType. 2. rmsnorm2d_fwd_traits_ from rmsnorm2d_fwd.hpp to rmsnorm2d_fwd_api.cpp and rmsnorm2d_fwd_instance_common.hpp
* 1. Supports various stride and percisions.
* Add support of Epilogue
* Add fuse and epilogue support to rmsnorm ref
* Modify rmsnorm example
* Refactor tests/examples
* Bug fix for newly added tests/examples
* Bug fix for new tests 2
* Modify smoke test scripts
remove dbg code
* Supports non-smooth dyanmic quant
* Update Rmsnorm2dFwd::GetName()
* rename xscale and prec_sx to smoothscale and prec_sm
Bug fix after rename
Remove files
* change example_rmsnorm2d_fwd.cpp
* update performance calculator
* Fix issue in two-pass when fuse add is enabled
* Remove comment of beta
---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 04dd314883]
This commit is contained in:
@@ -59,7 +59,7 @@ args:
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision (default:fp16)
|
||||
-prec_o output precision, set auto will be the same as input (default:auto)
|
||||
-prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
|
||||
-prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
|
||||
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
|
||||
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
@@ -69,7 +69,7 @@ args:
|
||||
```
|
||||
|
||||
## limitations
|
||||
Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose.
|
||||
Note that `fquant=2`, `fadd=2`, `prec_sm/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose.
|
||||
|
||||
```
|
||||
# some case
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
@@ -52,7 +52,7 @@ class layernorm_fwd_codegen:
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
@@ -71,7 +71,7 @@ struct layernorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
@@ -135,7 +135,7 @@ struct layernorm2d_fwd_traits_
|
||||
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
@@ -152,7 +152,7 @@ template <typename XDataType_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
XScaleDataType_,
|
||||
SmoothScaleDataType_,
|
||||
YScaleDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
@@ -170,7 +170,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
"""
|
||||
API_COMMON_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
@@ -189,9 +189,9 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
{{
|
||||
using XDataType = typename Traits_::XDataType;
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using XScaleDataType = typename Traits_::XScaleDataType;
|
||||
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveMeanInvStd,
|
||||
@@ -202,16 +202,16 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XBiasDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
PipelineTraits>;
|
||||
|
||||
@@ -224,7 +224,7 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
|
||||
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
||||
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, XScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
@@ -249,7 +249,7 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
@@ -285,7 +285,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_api_common.hpp"
|
||||
|
||||
@@ -374,7 +374,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
class h_traits:
|
||||
F_XDataType : str
|
||||
F_YDataType : str
|
||||
F_XScaleDataType : str
|
||||
F_SmoothScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
@@ -392,7 +392,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
@@ -477,8 +477,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
@@ -572,7 +572,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_x, scale_y = scale_type.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
||||
continue # skip non dynamic quant case
|
||||
if fused_quant == 1 and hs_key == 'big':
|
||||
@@ -582,8 +582,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_XScaleDataType = scale_y
|
||||
h_.F_YScaleDataType = scale_x
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_kXbias = xbias
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
|
||||
@@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
|
||||
.insert("prec_sx",
|
||||
.insert("prec_sm",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1")
|
||||
.insert("prec_sy",
|
||||
@@ -53,7 +53,7 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename XScaleDataType,
|
||||
typename SmoothScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
@@ -75,15 +75,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sx = arg_parser.get_str("prec_sx");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sx == "auto")
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sx = "fp32";
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
@@ -105,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
|
||||
using TypeConfig =
|
||||
LayerNormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
@@ -139,12 +140,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
|
||||
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
@@ -155,7 +156,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
@@ -165,7 +166,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
x_scale_buf.ToDevice(x_scale_host.data());
|
||||
sm_scale_buf.ToDevice(sm_scale_host.data());
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
@@ -186,11 +187,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
x_bias_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
@@ -279,8 +280,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
// input smooth outlier
|
||||
acc_(m_, n_) =
|
||||
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
|
||||
acc_(m_, n_) = acc_(m_, n_) *
|
||||
ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
|
||||
}
|
||||
}
|
||||
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
||||
@@ -402,16 +403,16 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sx = arg_parser.get_str("prec_sx");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sx == "auto")
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sx = "fp32";
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
@@ -420,33 +421,33 @@ int main(int argc, char* argv[])
|
||||
int save_mv = arg_parser.get_int("save_mv");
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv)
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,37 +8,40 @@
|
||||
#include "ck_tile/ops/layernorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
template <typename InType,
|
||||
typename OutType,
|
||||
typename SmoothSScaleDataType_,
|
||||
typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig;
|
||||
|
||||
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_>
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using XScaleDataType = XScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_>
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
using InvStdDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using XScaleDataType = XScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
using InvStdDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
@@ -52,10 +55,10 @@ struct layernorm2d_fwd_traits
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
|
||||
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sx; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_mean_var; //
|
||||
|
||||
@@ -1,11 +1,34 @@
|
||||
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
|
||||
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
|
||||
)
|
||||
|
||||
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
|
||||
message("adding ${TILE_RMSNORM2D_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
|
||||
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
|
||||
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d.hpp"
|
||||
#include <cstring>
|
||||
|
||||
@@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
assert(stride >= n);
|
||||
|
||||
using XDataType = DataType;
|
||||
using YDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using InvRmsDataType = ck_tile::null_type;
|
||||
using XDataType = DataType;
|
||||
using YDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using InvRmsDataType = ck_tile::null_type;
|
||||
using SmoothScaleDataType = ck_tile::null_type;
|
||||
using YScaleDataType = ck_tile::null_type;
|
||||
|
||||
using ComputeDataType = float;
|
||||
|
||||
@@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN
|
||||
false, // kSaveInvRms
|
||||
kTwoPass,
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
SmoothScaleDataType,
|
||||
YScaleDataType,
|
||||
Shape,
|
||||
true, // kPadN
|
||||
false, // kSaveInvRms
|
||||
kTwoPass>;
|
||||
PipelineTraits>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::
|
||||
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Default2DEpilogue>;
|
||||
|
||||
ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
y_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride,
|
||||
stride,
|
||||
stride,
|
||||
stride};
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
681
example/ck_tile/10_rmsnorm2d/generate.py
Normal file
681
example/ck_tile/10_rmsnorm2d/generate.py
Normal file
@@ -0,0 +1,681 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def get_if_str(idx, total, lase_else = True):
|
||||
if idx == 0:
|
||||
return 'if'
|
||||
elif idx < total - 1:
|
||||
return 'else if'
|
||||
else:
|
||||
if lase_else:
|
||||
return 'else'
|
||||
else:
|
||||
return 'else if'
|
||||
|
||||
FUSED_ADD_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'pras', # pre-norm
|
||||
'pra' ] # post-norm
|
||||
|
||||
FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
'no',
|
||||
'sdquant', # smooth dynamic quant
|
||||
'dquant' ] # dynamic quant (without sm_scale)
|
||||
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
return 'true'
|
||||
else:
|
||||
return 'false'
|
||||
|
||||
|
||||
class rmsnorm_fwd_codegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
YScaleDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
"""
|
||||
|
||||
API_COMMON_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = rmsnorm2d_fwd_args;
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const S& s, A a)
|
||||
{{
|
||||
using XDataType = typename Traits_::XDataType;
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveInvRms,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
PipelineTraits>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
|
||||
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0, DynamicQuantEpilogue, Default2DEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
rmsnorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
"""
|
||||
|
||||
API_PER_DTYPE = """
|
||||
{F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
||||
{F_per_n_case}
|
||||
}}
|
||||
"""
|
||||
API_PER_N_CASE = """
|
||||
{F_if} {F_N_COND} {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
API_INNER_CASE = """
|
||||
{F_if} {F_VEC_COND}
|
||||
r={F_instance_func}(s, a);
|
||||
"""
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
class k_fuesd_add_enum(IntEnum):
|
||||
F_NO_ADD = 0
|
||||
F_PRE_ADD = 1
|
||||
F_PRE_ADD_STORE_RESIDUAL = 2
|
||||
|
||||
class k_fused_sweep_enum(IntEnum):
|
||||
F_NO_SWEEP = 0
|
||||
F_RENORM = 1
|
||||
F_DYNAMIC_QUANT = 2
|
||||
|
||||
@dataclass
|
||||
class k_traits:
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd : bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : Any
|
||||
F_kFusedQuant : Any
|
||||
|
||||
@dataclass
|
||||
class k_shape:
|
||||
F_BlockTile : List[int]
|
||||
F_WarpPerBlock : List[int]
|
||||
F_WarpTile : List[int]
|
||||
F_Vector_ : List[int]
|
||||
@property
|
||||
def F_BlockSize(self) -> int:
|
||||
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
|
||||
|
||||
@dataclass
|
||||
class k_problem:
|
||||
F_XDataType : str
|
||||
F_GammaDataType : str
|
||||
F_ComputeDataType : str
|
||||
F_YDataType : str
|
||||
F_InvRmsDataType : str
|
||||
F_BlockShape : str
|
||||
F_Traits : Any #k_traits
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_one_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_two_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue_problem:
|
||||
F_AccDataType : str
|
||||
F_ODataType : str
|
||||
F_kPadM : bool
|
||||
F_kPadN : bool
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue:
|
||||
F_problem : Any
|
||||
|
||||
@dataclass
|
||||
class k_kernel:
|
||||
F_pipeline : Any
|
||||
F_epilogue : Any
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_XDataType : str
|
||||
F_YDataType : str
|
||||
F_SmoothScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
F_ThreadPerBlock_M : int
|
||||
F_ThreadPerBlock_N : int
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveInvRms : bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@property
|
||||
def call_name(self) -> str:
|
||||
return f'rmsnorm2d_fwd_<traits_<{self.trait_name}>>'
|
||||
|
||||
# string when define this kernel
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return f'template float rmsnorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
|
||||
|
||||
# this class hold kernel under same source file
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair : str
|
||||
F_N : str
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
prec_i, prec_o = self.F_DataTypePair.split(',')
|
||||
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
||||
nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}'
|
||||
if self.F_add != 0:
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
return nnn
|
||||
|
||||
@property
|
||||
def instance_name(self) ->str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def content(self) ->str:
|
||||
instance_defs = ''
|
||||
for ins in self.instance_list:
|
||||
instance_defs += ins.def_name + '\n'
|
||||
return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
||||
|
||||
@property
|
||||
def name_api(self) -> str:
|
||||
return 'rmsnorm2d_fwd_api'
|
||||
|
||||
@property
|
||||
def name_common_header(self) -> str:
|
||||
return 'rmsnorm2d_fwd_api_common'
|
||||
|
||||
@property
|
||||
def content_api(self) -> str:
|
||||
# 1 sort based on dtype
|
||||
t_dtype_dict = dict()
|
||||
blobs = self.get_blobs()
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
||||
|
||||
d_str = ''
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
n_str = ''
|
||||
for i_n, n_ in enumerate(blob_per_t):
|
||||
blob_per_n = blob_per_t[n_]
|
||||
inner_str = ""
|
||||
for i_b, b_ in enumerate(blob_per_n):
|
||||
# generate single kernel instance file
|
||||
#vec_str = ""
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
||||
# _if = 'if' if i_ins == 0 else 'else if'
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
|
||||
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
|
||||
prec_i, prec_o = dtype_.split(',')
|
||||
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
|
||||
|
||||
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
|
||||
return api_base
|
||||
|
||||
@property
|
||||
def content_common_header(self) -> str:
|
||||
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
||||
|
||||
def get_blobs(self):
|
||||
h_traits = rmsnorm_fwd_codegen.h_traits
|
||||
h_instance = rmsnorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
|
||||
# rm rn tm tn vn pd mv 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'rmsnorm2d_fwd_blobs.txt'
|
||||
blobs = self.get_blobs()
|
||||
with list_p.open('w') as list_f:
|
||||
# api related file
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# kernel instance file
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
def gen_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
blobs = self.get_blobs()
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs()
|
||||
|
||||
|
||||
def gen_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK rmsnorm kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd[all]',
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
|
||||
# the directory for list_blobs/gen_blobs to write files into
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="the path where all the blobs are going to be generated"
|
||||
)
|
||||
|
||||
# this script have 2 modes
|
||||
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
|
||||
# this is useful in build system like cmake to construct source code dependency, by
|
||||
# reading the content out of this file
|
||||
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
|
||||
# like FA, only need to use this mode
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action='store_true',
|
||||
help="list all the kernels to a file, "
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action='store_true',
|
||||
help="generate all kernels into different tile"
|
||||
)
|
||||
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--traits",
|
||||
default="all",
|
||||
required=False,
|
||||
help="enable/disable some feature. default generate all"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# print(f'{args.list_blobs}-{args.gen_blobs}')
|
||||
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
|
||||
print('gen_blobs/list_blobs must specify only one option')
|
||||
sys.exit()
|
||||
|
||||
p = Path(args.working_path)
|
||||
if not p.exists():
|
||||
p.mkdir()
|
||||
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
@@ -1,146 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename data_type>
|
||||
float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
|
||||
rmsnorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
if(a.n <= 64) {
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 128) {
|
||||
if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 256) {
|
||||
if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 512) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 768) {
|
||||
if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1024) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1536) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 2048) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 3072) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 4096) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
|
||||
}
|
||||
else if(a.n > 4096) {
|
||||
if (a.n % 8 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
|
||||
else
|
||||
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
|
||||
}
|
||||
else if(t.data_type.compare("bf16") == 0)
|
||||
{
|
||||
return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
#if 0
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,22 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
#if 0
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,65 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = rmsnorm2d_fwd_args;
|
||||
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const S& s, A a)
|
||||
{
|
||||
using DataType = typename Traits_::DataType;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<DataType>::XDataType,
|
||||
typename RmsnormTypeConfig<DataType>::GammaDataType,
|
||||
typename RmsnormTypeConfig<DataType>::ComputeDataType,
|
||||
typename RmsnormTypeConfig<DataType>::YDataType,
|
||||
typename RmsnormTypeConfig<DataType>::InvRmsDataType,
|
||||
typename Traits_::Shape,
|
||||
Traits_::kPadN,
|
||||
Traits_::kSaveInvRms,
|
||||
Traits_::kTwoPass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -19,17 +19,37 @@ auto get_elimit<ck_tile::bf16_t>()
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
double rtol = 1e-02;
|
||||
double atol = 1.0;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
|
||||
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
|
||||
.insert("prec_sm",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1")
|
||||
.insert("prec_sy",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
@@ -37,28 +57,68 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType, bool SaveRms>
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename SmoothScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveRms>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= n);
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
|
||||
if(xr_stride < 0)
|
||||
xr_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
|
||||
if(yr_stride < 0)
|
||||
yr_stride = n;
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig = RmsnormTypeConfig<DataType>;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8")
|
||||
{
|
||||
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
using TypeConfig =
|
||||
RmsnormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
using InvRmsDataType =
|
||||
std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
|
||||
@@ -66,43 +126,84 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
sm_scale_buf.ToDevice(sm_scale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_o)
|
||||
{
|
||||
base_str += "|" + prec_o;
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
base_str += std::string("(") + prec_sy + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
rmsnorm2d_fwd_traits traits{data_type, SaveRms};
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
y_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_invRms, unsupported yet
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
x_stride, // x row_stride
|
||||
xr_stride, // x residule row stride
|
||||
y_stride, // y row stride
|
||||
yr_stride}; // y residule row stride
|
||||
|
||||
float ave_time = rmsnorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
|
||||
num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
|
||||
num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0;
|
||||
num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0;
|
||||
num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
@@ -112,38 +213,131 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
// TODO we accumulate directly to x_host for simplcity here...
|
||||
std::transform(x_host.mData.cbegin(),
|
||||
x_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
x_host.mData.begin(),
|
||||
[](auto x_, auto r_) {
|
||||
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
|
||||
ck_tile::type_convert<ComputeDataType>(r_);
|
||||
return ck_tile::type_convert<XDataType>(o_);
|
||||
});
|
||||
}
|
||||
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
|
||||
int N_ = acc_.mDesc.get_lengths()[1];
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
// input smooth outlier
|
||||
acc_(m_, n_) = acc_(m_, n_) *
|
||||
ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
|
||||
}
|
||||
}
|
||||
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
const auto a = ck_tile::abs(acc_(m_, n_));
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>();
|
||||
if(stride == n)
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<YDataType>();
|
||||
if(x_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(y_residual_host_dev,
|
||||
x_host,
|
||||
std::string("\nADD Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
|
||||
y_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
|
||||
y_host_ref.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
|
||||
y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
|
||||
y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("OUT[") + std::to_string(i_r) +
|
||||
std::string("\nOUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * yr_stride,
|
||||
y_residual_host_dev.begin() + i_r * yr_stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("\nADD[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
y_scale_buf.FromDevice(y_scale_host_dev.data());
|
||||
pass &= ck_tile::check_err(y_scale_host_dev,
|
||||
y_scale_host_ref,
|
||||
std::string("\nSCALE Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
@@ -156,23 +350,55 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
if(data_type == "fp16" && save_rms)
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
|
||||
prec_o = prec_i;
|
||||
}
|
||||
else if(data_type == "fp16" && !save_rms)
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
else if(data_type == "bf16" && save_rms)
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
else if(data_type == "bf16" && !save_rms)
|
||||
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -8,27 +8,34 @@
|
||||
#include "ck_tile/ops/rmsnorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
template <typename InType,
|
||||
typename OutType,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig;
|
||||
|
||||
template <>
|
||||
struct RmsnormTypeConfig<ck_tile::half_t>
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using InvRmsDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using InvRmsDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct RmsnormTypeConfig<ck_tile::bf16_t>
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using InvRmsDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using InvRmsDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct rmsnorm2d_fwd_traits
|
||||
{
|
||||
std::string data_type;
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_rms;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -m=99 -n=13
|
||||
$EXE -prec=$pr_i -m=17 -n=16
|
||||
$EXE -prec=$pr_i -m=1 -n=100
|
||||
$EXE -prec=$pr_i -m=4 -n=128
|
||||
$EXE -prec=$pr_i -m=80 -n=127
|
||||
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
|
||||
$EXE -prec=$pr_i -m=7 -n=599
|
||||
$EXE -prec=$pr_i -m=19 -n=512
|
||||
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
|
||||
$EXE -prec=$pr_i -m=11 -n=510
|
||||
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
|
||||
$EXE -prec=$pr_i -m=91 -n=636
|
||||
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
|
||||
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
|
||||
$EXE -prec=$pr_i -m=31 -n=1024
|
||||
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec=$pr_i -m=8 -n=1501
|
||||
$EXE -prec=$pr_i -m=3 -n=1826
|
||||
$EXE -prec=$pr_i -m=5 -n=2040
|
||||
$EXE -prec=$pr_i -m=7 -n=2734
|
||||
$EXE -prec=$pr_i -m=1 -n=3182
|
||||
$EXE -prec=$pr_i -m=9 -n=4096
|
||||
$EXE -prec=$pr_i -m=3 -n=8192
|
||||
$EXE -prec=$pr_i -m=1 -n=10547
|
||||
$EXE -prec=$pr_i -m=3 -n=17134
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= n);
|
||||
assert(x_stride >= n);
|
||||
|
||||
using XDataType = DataType;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = DataType;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
constexpr bool kTwoPass = true;
|
||||
|
||||
@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
|
||||
XScaleDataType,
|
||||
SmoothScaleDataType,
|
||||
ComputeDataType,
|
||||
YScaleDataType,
|
||||
QYDataType,
|
||||
@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Kernel = ck_tile::Smoothquant<Pipeline>;
|
||||
|
||||
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
smscale_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
|
||||
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
|
||||
|
||||
for(int m_ = 0; m_ < m; ++m_)
|
||||
{
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
|
||||
y_host(m_, n_) = v_x * v_xscale;
|
||||
y_host(m_, n_) = v_x * v_smscale;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
|
||||
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "smoothquant.hpp"
|
||||
@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
|
||||
|
||||
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
|
||||
typename SmoothquantTypeConfig<DataType>::XDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::XScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::ComputeDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::YScaleDataType,
|
||||
typename SmoothquantTypeConfig<DataType>::QYDataType,
|
||||
|
||||
@@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = SmoothquantTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using XScaleDataType = typename TypeConfig::XScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
smoothquant_traits traits{data_type};
|
||||
|
||||
smoothquant_args args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
smscale_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
m,
|
||||
@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float ave_time = smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n +
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
|
||||
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// smooth outlier
|
||||
{
|
||||
auto f = [&](auto n_) {
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
|
||||
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
|
||||
|
||||
for(int m_ = 0; m_ < m; ++m_)
|
||||
{
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
|
||||
y_host(m_, n_) = v_x * v_xscale;
|
||||
y_host(m_, n_) = v_x * v_smscale;
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
|
||||
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
|
||||
template <>
|
||||
struct SmoothquantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct SmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "moe_smoothquant.hpp"
|
||||
@@ -35,7 +35,7 @@ float moe_smoothquant_(const S& s, A a)
|
||||
|
||||
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::XScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::SmoothScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType,
|
||||
typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
|
||||
|
||||
@@ -91,15 +91,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using XScaleDataType = typename TypeConfig::XScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<XScaleDataType> xscale_host({experts * hidden_size});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
|
||||
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
|
||||
@@ -110,16 +110,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
xscale_buf.ToDevice(xscale_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
topk_ids_buf.ToDevice(topk_ids_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
@@ -129,7 +129,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
moe_smoothquant_traits traits{data_type};
|
||||
|
||||
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
|
||||
xscale_buf.GetDeviceBuffer(),
|
||||
smscale_buf.GetDeviceBuffer(),
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
@@ -143,9 +143,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float ave_time = moe_smoothquant(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * tokens * hidden_size + sizeof(XScaleDataType) * topk * hidden_size +
|
||||
sizeof(YScaleDataType) * topk * tokens + sizeof(QYDataType) * topk * tokens * hidden_size;
|
||||
std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size +
|
||||
sizeof(SmoothScaleDataType) * topk * hidden_size +
|
||||
sizeof(YScaleDataType) * topk * tokens +
|
||||
sizeof(QYDataType) * topk * tokens * hidden_size;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
@@ -165,11 +166,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(int i_h = 0; i_h < hidden_size; ++i_h)
|
||||
{
|
||||
auto v_xscale = ck_tile::type_convert<ComputeDataType>(
|
||||
xscale_host(i_expert * hidden_size + i_h));
|
||||
auto v_smscale = ck_tile::type_convert<ComputeDataType>(
|
||||
smscale_host(i_expert * hidden_size + i_h));
|
||||
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
|
||||
// y_host(i_token * topk + i_topk, i_h) = v_x * v_xscale;
|
||||
y_host(i_topk * tokens + i_token, i_h) = v_x * v_xscale;
|
||||
// y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
|
||||
y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,21 +14,21 @@ struct MoeSmoothquantTypeConfig;
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using XScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using SmoothScaleDataType = float;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
|
||||
@@ -8,16 +8,40 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Note: for simplicity, each functor only care about single M
|
||||
struct reference_rmsnorm2d_default_epilogue
|
||||
{
|
||||
template <typename OutDataType, typename AccDataType>
|
||||
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
|
||||
{
|
||||
const int N = acc.mDesc.get_lengths()[1];
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutDataType, typename AccDataType>
|
||||
auto operator()(int m, const HostTensor<AccDataType>& acc)
|
||||
{
|
||||
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
|
||||
operator()(m, o, acc);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename InvRmsDataType>
|
||||
typename InvRmsDataType,
|
||||
typename Epilogue = reference_rmsnorm2d_default_epilogue>
|
||||
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
const HostTensor<GammaDataType>& gamma_n,
|
||||
HostTensor<YDataType>& y_m_n,
|
||||
HostTensor<InvRmsDataType>& invRms_m,
|
||||
ComputeDataType epsilon)
|
||||
ComputeDataType epsilon,
|
||||
Epilogue epilogue_functor = {})
|
||||
{
|
||||
auto rmsnorm2d_fwd_func = [&](auto m) {
|
||||
const int N = x_m_n.mDesc.get_lengths()[1];
|
||||
@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
|
||||
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
|
||||
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
|
||||
|
||||
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
|
||||
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
|
||||
auto y = x * divisor * gamma;
|
||||
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
|
||||
acc(m, n) = x * divisor * gamma;
|
||||
}
|
||||
|
||||
epilogue_functor(m, y_m_n, acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,19 +24,19 @@ struct DynamicQuantEpilogueTraits
|
||||
|
||||
// this epilogue just store out a M*N matrix, row major
|
||||
template <typename AccDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct DynamicQuantEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
// TODO: we should put descriptor creation function into policy
|
||||
@@ -45,7 +45,7 @@ struct DynamicQuantEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
|
||||
@@ -78,7 +78,7 @@ struct DynamicQuantEpilogue
|
||||
#if 0
|
||||
// don't remove this
|
||||
// Note that if we set encoding purposely like this, you will result in compile fail
|
||||
// TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
|
||||
// TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
@@ -105,34 +105,18 @@ struct DynamicQuantEpilogue
|
||||
return reduce_crosswarp_sync.GetSmemSize();
|
||||
}
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp,
|
||||
typename XScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const XScaleWindow& x_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
auto reduce = GetBlockReduce2d();
|
||||
auto reduce_sync = GetBlockReduce2dSync();
|
||||
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
|
||||
const auto x_scale_window =
|
||||
make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
auto x_scale = load_tile(x_scale_window);
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
|
||||
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
|
||||
});
|
||||
|
||||
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
|
||||
|
||||
auto row_absmax = [&]() {
|
||||
@@ -184,5 +168,45 @@ struct DynamicQuantEpilogue
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
|
||||
// Smooth Dynamic Quant
|
||||
template <typename ODramWindowTmp,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
|
||||
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
|
||||
});
|
||||
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
|
||||
}
|
||||
|
||||
// Dynamic Quant
|
||||
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
@@ -43,16 +43,16 @@ struct Layernorm2dFwd
|
||||
using Epilogue = remove_cvref_t<Epilogue_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XBiasDataType = remove_cvref_t<typename Problem::XBiasDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
|
||||
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
@@ -84,7 +84,7 @@ struct Layernorm2dFwd
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_x_bias; // [1, n], bias, prec same as input
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_beta; // [1, n], beta, prec same as input
|
||||
@@ -111,7 +111,7 @@ struct Layernorm2dFwd
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_x_scale,
|
||||
hargs.p_sm_scale,
|
||||
hargs.p_x_bias,
|
||||
hargs.p_gamma,
|
||||
hargs.p_beta,
|
||||
@@ -171,7 +171,7 @@ struct Layernorm2dFwd
|
||||
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
|
||||
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) {
|
||||
@@ -356,18 +356,18 @@ struct Layernorm2dFwd
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto x_scale_window = [&]() {
|
||||
auto sm_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_x_scale),
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
|
||||
make_tuple(kargs.n),
|
||||
number<Vector_N>{});
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // x_scale no need pad
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
@@ -405,7 +405,7 @@ struct Layernorm2dFwd
|
||||
y_residual_window,
|
||||
mean_window,
|
||||
inv_std_window,
|
||||
x_scale_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& x_scale_window_,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
|
||||
}
|
||||
else
|
||||
Epilogue{}(y_window_, ln);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,23 +15,23 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
typename Traits_>
|
||||
struct Layernorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XBiasDataType = remove_cvref_t<XBiasDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
typename YResidualWindow,
|
||||
typename MeanWindow,
|
||||
typename InvStdWindow,
|
||||
typename XScaleWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
const YResidualWindow& y_residual_window_,
|
||||
MeanWindow& mean_window,
|
||||
InvStdWindow& inv_std_window,
|
||||
const XScaleWindow& /*x_scale_window*/,
|
||||
const SmoothScaleWindow& /*sm_scale_window*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
|
||||
@@ -8,5 +8,6 @@
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -1,50 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// host side args
|
||||
struct Rmsnorm2dFwdHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
|
||||
const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
|
||||
const void* p_gamma; // [1, n], gamma, prec same as input
|
||||
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
|
||||
void* p_y; // [m, n], output, fp16/bf16
|
||||
void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
|
||||
void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
|
||||
void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
template <typename Pipeline_>
|
||||
template <typename Pipeline_, typename Epilogue_>
|
||||
struct Rmsnorm2dFwd
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Epilogue = remove_cvref_t<Epilogue_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
// for simplicity, shortcut input/output type is same as X
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::kTwoPass;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_x_residual;
|
||||
const void* p_sm_scale;
|
||||
const void* p_gamma;
|
||||
|
||||
void* p_y;
|
||||
void* p_y_residual;
|
||||
void* p_y_scale;
|
||||
void* p_invRms;
|
||||
|
||||
float epsilon;
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
using Hargs = Rmsnorm2dFwdHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_x_residual,
|
||||
hargs.p_sm_scale,
|
||||
hargs.p_gamma,
|
||||
hargs.p_y,
|
||||
hargs.p_y_residual,
|
||||
hargs.p_y_scale,
|
||||
hargs.p_invRms,
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.stride};
|
||||
hargs.x_stride,
|
||||
hargs.xr_stride,
|
||||
hargs.y_stride,
|
||||
hargs.yr_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
|
||||
// clang-format on
|
||||
|
||||
// in byte
|
||||
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
// clang-format off
|
||||
using S_ = typename Problem::BlockShape;
|
||||
auto surfix = [&] () {
|
||||
std::string n;
|
||||
if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
|
||||
if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveInvRms) n += "_rms";
|
||||
if (kTwoPass) n += "_2p";
|
||||
return n; }();
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
return _SS_("rmsnorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
|
||||
auto prec_str = [&] () {
|
||||
std::string base_str = _SS_(t2s<XDataType>::name);
|
||||
if (!std::is_same_v<XDataType, YDataType>) {
|
||||
base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
if (kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) {
|
||||
base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
|
||||
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
|
||||
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
|
||||
_SS_(Pipeline::name) + surfix;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto x_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XResidualDataType*>(kargs.p_x_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.xr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto gamma_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
auto y_residual_window = [&]() {
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YResidualDataType*>(kargs.p_y_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.yr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
auto tmp2_ = pad_tensor_view(tmp_,
|
||||
make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto inv_rms_window = [&]() {
|
||||
if constexpr(kSaveInvRms)
|
||||
{
|
||||
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}();
|
||||
|
||||
auto sm_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
|
||||
make_tuple(kargs.n),
|
||||
number<Vector_N>{});
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_N>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
auto y_scale_window = [&]() {
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT ||
|
||||
kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
const auto win_ = [&]() {
|
||||
const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<YScaleDataType*>(kargs.p_y_scale),
|
||||
make_tuple(kargs.m),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
tmp_0_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(make_tuple(number<Block_M>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window,
|
||||
x_residual_window,
|
||||
gamma_window,
|
||||
y_window,
|
||||
y_residual_window,
|
||||
inv_rms_window,
|
||||
sm_scale_window,
|
||||
y_scale_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.n,
|
||||
smem);
|
||||
smem,
|
||||
Epilogue{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<P_>{};
|
||||
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window,
|
||||
YWindow& y_window_,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window_,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
const auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
}
|
||||
}
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d(
|
||||
x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
|
||||
auto square_sum = block_reduce2d(acc,
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
// rmsnorm computation
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
auto rmsn_ = acc[idx] * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
rmsn(idx) = rmsn_;
|
||||
});
|
||||
store_tile(y_window, y);
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, rmsn);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,25 +12,25 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename InvRmsDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename BlockShape_,
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
typename Traits_>
|
||||
struct Rmsnorm2dFwdPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using InvRmsDataType = remove_cvref_t<InvRmsDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::kSaveInvRms;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename GammaWindow, typename YWindow, typename InvRmsWindow>
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& /*sm_scale_window_*/,
|
||||
YScaleWindow& /*y_scale_window*/,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem) const
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Problem::BlockShape
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
using XTensorType = decltype(load_tile(x_window));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<XTensorType>();
|
||||
using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
|
||||
set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
block_reduce2d(x, square_sum, reduce_square_sum_func);
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(x_residual_window, {0, Block_N});
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
|
||||
move_tile_window(y_residual_window, {0, Block_N});
|
||||
}
|
||||
}
|
||||
|
||||
block_reduce2d(acc, square_sum, reduce_square_sum_func);
|
||||
}
|
||||
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {stride_to_right_most_window});
|
||||
move_tile_window(y_window, {0, stride_to_right_most_window});
|
||||
|
||||
// rmsnorm computation
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
// load gamma/beta (TODO: support no gamma/beta?)
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
|
||||
{
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
});
|
||||
}
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
const auto x_ = type_convert<ComputeDataType>(x[idx]);
|
||||
auto y_ = x_ * inv_rms_[i_idx] * gamma_;
|
||||
auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
|
||||
|
||||
y(idx) = type_convert<YDataType>(y_);
|
||||
rmsn(idx) = rmsn_;
|
||||
});
|
||||
|
||||
store_tile(y_window, y);
|
||||
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
|
||||
Epilogue{}(y_window, rmsn);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(x_residual_window, {0, -Block_N});
|
||||
move_tile_window(gamma_window, {-Block_N});
|
||||
move_tile_window(y_window, {0, -Block_N});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class Rmsnorm2dFusedAddEnum
|
||||
{
|
||||
NO_ADD = 0,
|
||||
// fused add before RMSNorm and store result to global
|
||||
PRE_ADD_STORE = 1,
|
||||
// fused add before RMSNorm, but not store result
|
||||
PRE_ADD = 2,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Rmsnorm2dFusedAddEnum> struct Rmsnorm2dFusedAddEnumName;
|
||||
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
|
||||
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
|
||||
template<> struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Rmsnorm2dFusedQuantEnum
|
||||
{
|
||||
NO_SWEEP = 0,
|
||||
SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale
|
||||
DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Rmsnorm2dFusedQuantEnum> struct Rmsnorm2dFusedQuantEnumName;
|
||||
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
|
||||
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
|
||||
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
|
||||
// clang-format on
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_,
|
||||
Rmsnorm2dFusedAddEnum kFusedAdd_,
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_>
|
||||
struct Rmsnorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace ck_tile {
|
||||
struct MoeSmoothquantHostArgs
|
||||
{
|
||||
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
|
||||
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
|
||||
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
|
||||
@@ -33,11 +33,11 @@ struct MoeSmoothquant
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -57,7 +57,7 @@ struct MoeSmoothquant
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
|
||||
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_smscale; // [experts, hidden_size], input, columnwise scale, fp32
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
|
||||
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
|
||||
@@ -75,7 +75,7 @@ struct MoeSmoothquant
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_xscale,
|
||||
hargs.p_smscale,
|
||||
hargs.p_topk_ids,
|
||||
hargs.p_yscale,
|
||||
hargs.p_qy,
|
||||
@@ -153,9 +153,10 @@ struct MoeSmoothquant
|
||||
}();
|
||||
|
||||
// [experts, hidden_size],
|
||||
const auto xscale_window = [&]() {
|
||||
const auto smscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size,
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
|
||||
i_expert * kargs.hidden_size,
|
||||
make_tuple(kargs.hidden_size),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
@@ -198,7 +199,7 @@ struct MoeSmoothquant
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
|
||||
Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace ck_tile {
|
||||
// host side args
|
||||
struct SmoothquantHostArgs
|
||||
{
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_xscale; // [1, n], input, columnwise scale, fp32
|
||||
const void* p_x; // [m ,n], input, fp16/bf16
|
||||
const void* p_smscale; // [1, n], input, columnwise scale, fp32
|
||||
|
||||
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale)
|
||||
void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale
|
||||
void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
|
||||
void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
@@ -30,11 +30,11 @@ struct Smoothquant
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = typename Pipeline::Problem;
|
||||
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
using XDataType = remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
@@ -52,7 +52,7 @@ struct Smoothquant
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_xscale;
|
||||
const void* p_smscale;
|
||||
|
||||
void* p_yscale;
|
||||
void* p_qy;
|
||||
@@ -67,7 +67,7 @@ struct Smoothquant
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
return Kargs{hargs.p_x,
|
||||
hargs.p_xscale,
|
||||
hargs.p_smscale,
|
||||
hargs.p_yscale,
|
||||
hargs.p_qy,
|
||||
hargs.m,
|
||||
@@ -134,9 +134,9 @@ struct Smoothquant
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
|
||||
const auto xscale_window = [&]() {
|
||||
const auto smscale_window = [&]() {
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XScaleDataType*>(kargs.p_xscale),
|
||||
static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
|
||||
make_tuple(kargs.n),
|
||||
make_tuple(1),
|
||||
number<Vector_N>{},
|
||||
@@ -177,7 +177,7 @@ struct Smoothquant
|
||||
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem);
|
||||
Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeSmoothScaleBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
template <typename XWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename QYWindow,
|
||||
typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
const SmoothScaleWindow& smscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t,
|
||||
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
auto smscale_window = make_tile_window(
|
||||
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_absmax_func = ReduceOp::AbsMax{};
|
||||
auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
|
||||
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
|
||||
auto block_reduce2d_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
|
||||
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
auto absmax = [&]() {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
|
||||
// Y = X * SmoothScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
|
||||
template <typename XDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename QYDataType_,
|
||||
@@ -18,12 +18,12 @@ template <typename XDataType_,
|
||||
bool kTwoPass_>
|
||||
struct SmoothquantPipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using XScaleDataType = remove_cvref_t<XScaleDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using QYDataType = remove_cvref_t<QYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
|
||||
using QYDataType = remove_cvref_t<QYDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,11 +16,11 @@ struct SmoothquantPipelineTwoPass
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<typename Problem::SmoothScaleDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
|
||||
@@ -39,9 +39,12 @@ struct SmoothquantPipelineTwoPass
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
|
||||
template <typename XWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename QYWindow,
|
||||
typename YScaleWindow>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XScaleWindow& xscale_window_,
|
||||
const SmoothScaleWindow& smscale_window_,
|
||||
YScaleWindow& yscale_window,
|
||||
QYWindow& qy_window,
|
||||
ck_tile::index_t row_size,
|
||||
@@ -49,8 +52,8 @@ struct SmoothquantPipelineTwoPass
|
||||
{
|
||||
auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto xscale_window = make_tile_window(
|
||||
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
|
||||
auto smscale_window = make_tile_window(
|
||||
smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
|
||||
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
index_t num_n_tile_iteration =
|
||||
@@ -76,14 +79,14 @@ struct SmoothquantPipelineTwoPass
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
constexpr auto x_size_per_row =
|
||||
x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
|
||||
@@ -94,7 +97,7 @@ struct SmoothquantPipelineTwoPass
|
||||
block_reduce2d(y, absmax, reduce_absmax_func);
|
||||
|
||||
move_tile_window(x_window, {0, Block_N});
|
||||
move_tile_window(xscale_window, {Block_N});
|
||||
move_tile_window(smscale_window, {Block_N});
|
||||
}
|
||||
|
||||
// compute absmax, cross-lane->cross-warp
|
||||
@@ -114,20 +117,20 @@ struct SmoothquantPipelineTwoPass
|
||||
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {-Block_N});
|
||||
move_tile_window(smscale_window, {-Block_N});
|
||||
move_tile_window(qy_window, {0, stride_to_right_most_window});
|
||||
|
||||
// recompute y and quantize y to qy
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto x = load_tile(x_window);
|
||||
const auto xscale = load_tile(xscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
const auto x = load_tile(x_window);
|
||||
const auto smscale = load_tile(smscale_window);
|
||||
const auto y = tile_elementwise_in(
|
||||
[&](const auto& a, const auto& b) {
|
||||
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
|
||||
},
|
||||
x,
|
||||
xscale);
|
||||
smscale);
|
||||
|
||||
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
|
||||
sweep_tile(qy, [&](auto idx) {
|
||||
@@ -138,7 +141,7 @@ struct SmoothquantPipelineTwoPass
|
||||
store_tile(qy_window, qy);
|
||||
|
||||
move_tile_window(x_window, {0, -Block_N});
|
||||
move_tile_window(xscale_window, {0, -Block_N});
|
||||
move_tile_window(smscale_window, {0, -Block_N});
|
||||
move_tile_window(qy_window, {0, -Block_N});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user