// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/rmsnorm2d.hpp" #include template struct RmsnormTypeConfig; template struct RmsnormTypeConfig { using XDataType = ck_tile::half_t; using YDataType = OutType; using GammaDataType = ck_tile::half_t; using InvRmsDataType = ck_tile::half_t; using UnquantYDataType = ck_tile::half_t; using ComputeDataType = float; using SmoothScaleDataType = SmoothScaleDataType_; using YScaleDataType = YScaleDataType_; }; template struct RmsnormTypeConfig { using XDataType = ck_tile::bf16_t; using YDataType = OutType; using GammaDataType = ck_tile::bf16_t; using InvRmsDataType = ck_tile::bf16_t; using UnquantYDataType = ck_tile::bf16_t; using ComputeDataType = float; using SmoothScaleDataType = SmoothScaleDataType_; using YScaleDataType = YScaleDataType_; }; // runtime args struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs { }; template float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); // This is the public API, will be generated by script struct rmsnorm2d_fwd_traits { std::string prec_i; // input precision std::string prec_o; // output precision // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise // can set arbitrary(will skip check) std::string prec_sm; // x-scale, used for [1*N] input smooth quant std::string prec_sy; // y-scale, used for [M*1] output for next layer bool save_rms; bool save_unquant; int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);