mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* chore(copyright): update copyright header for test directory * chore(copyright): update copyright header for test directory * chore(copyright): update copyright header for client_example directory * chore(copyright): update copyright header for test directory
71 lines
2.6 KiB
C++
71 lines
2.6 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/ops/layernorm2d.hpp"
|
|
#include <string>
|
|
|
|
// Primary template: maps an (input, output, smooth-scale, y-scale) type
// combination to the element types used by the layernorm2d forward example.
// It is only declared here; the partial specializations that follow (keyed on
// InType) provide the actual definitions, so an unsupported InType fails to
// compile as soon as a member of the config is used.
template <typename InType,
          typename OutType,
          typename SmoothScaleDataType_,
          typename YScaleDataType_>
struct LayerNormTypeConfig;
|
|
|
|
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
|
struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
|
{
|
|
using XDataType = ck_tile::half_t;
|
|
using YDataType = OutType;
|
|
using XBiasDataType = ck_tile::half_t;
|
|
using GammaDataType = ck_tile::half_t;
|
|
using BetaDataType = ck_tile::half_t;
|
|
using MeanDataType = ck_tile::half_t;
|
|
using InvStdDataType = ck_tile::half_t;
|
|
using ComputeDataType = float;
|
|
using SmoothScaleDataType = SmoothScaleDataType_;
|
|
using YScaleDataType = YScaleDataType_;
|
|
};
|
|
|
|
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
|
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
|
{
|
|
using XDataType = ck_tile::bf16_t;
|
|
using YDataType = OutType;
|
|
using XBiasDataType = ck_tile::bf16_t;
|
|
using GammaDataType = ck_tile::bf16_t;
|
|
using BetaDataType = ck_tile::bf16_t;
|
|
using MeanDataType = ck_tile::bf16_t;
|
|
using InvStdDataType = ck_tile::bf16_t;
|
|
using ComputeDataType = float;
|
|
using SmoothScaleDataType = SmoothScaleDataType_;
|
|
using YScaleDataType = YScaleDataType_;
|
|
};
|
|
|
|
// runtime args
|
|
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
|
|
{
|
|
};
|
|
|
|
// This is the public API, will be generated by script
//
// Describes which layernorm2d forward variant is requested. The scalar members
// have default initializers so that a default-initialized object
// (`layernorm2d_fwd_traits t;`) never holds indeterminate values; positional
// aggregate initialization with all eight fields keeps working (C++14 and later).
struct layernorm2d_fwd_traits
{
    std::string prec_i; // input precision
    std::string prec_o; // output precision

    // fused_quant == 1: both prec_sm and prec_sy must be set to a proper string.
    // fused_quant == 2: only prec_sy must be set to a proper string.
    // A field that is not required can be set arbitrarily (its check is skipped).
    std::string prec_sm; // x-scale, used for [1*N] input smooth quant
    std::string prec_sy; // y-scale, used for [M*1] output for next layer

    // NOTE(review): presumably requests that mean / inverse-std be written
    // out as well -- confirm against the kernel dispatch code.
    bool save_mean_var = false;
    int xbias          = 0; // 0:no-bias, 1:add bias
    int fused_add      = 0; // 0:no-add, 1:pre-add-store, 2:pre-add
    int fused_quant    = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
|
|
|
// Entry point of the layernorm2d forward example. Takes the variant selection
// (traits) and the runtime arguments (args) by value, plus the stream
// configuration by const reference. Only declared here; the definition lives
// elsewhere.
// NOTE(review): the float result is presumably the measured kernel time --
// confirm against the definition.
float layernorm2d_fwd(layernorm2d_fwd_traits traits,
                      layernorm2d_fwd_args args,
                      const ck_tile::stream_config& stream);
|