mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
@@ -1,11 +1,34 @@
|
||||
# APIs this example knows how to generate; LAYERNORM2D_FWD_ENABLE_APIS selects a subset.
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
    "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
    set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
endif()

# generate a list of kernels, but not actually emit files at config stage;
# --list_blobs only writes a manifest txt used to wire up dependencies below.
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
    RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
    message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)

# Build-time rule that actually emits the kernel instance sources.
# FIX: DEPENDS on generate.py so editing the generator re-triggers generation
# (previously the rule had no inputs and went stale); VERBATIM gives portable
# argument escaping across generators/platforms.
add_custom_command(
    OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
    DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py
    VERBATIM
)

set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"

message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
# NOTE(review): file(GLOB) misses newly added instance files until a re-configure;
# consider an explicit source list or file(GLOB ... CONFIGURE_DEPENDS) (CMake >= 3.12).
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})

set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
|
||||
@@ -1,6 +1,42 @@
|
||||
# Layernorm2D forward
|
||||
|
||||
This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation.
|
||||
This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation.
|
||||
|
||||
# Implementation and feature support
|
||||
|
||||
## welford online algorithm
|
||||
We use the Welford online algorithm to update `mean`/`variance` block by block. For the `N <= 4096` case we can compute `mean`/`var`/`normalization` within one loop, which we call `one-pass`. For large N it is hard to keep `mean`/`var` inside register/LDS while also computing the `normalization`, so we need to load the input twice: first to compute `mean`/`var` block-by-block, then a second time to compute the `normalization`. We call this `two-pass`.
|
||||
|
||||
## mean/variance save
|
||||
In training case the mean/variance need to store out (TBD, not supported yet)
|
||||
|
||||
## prenorm/postnorm
|
||||
|
||||

|
||||
|
||||
since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default).
|
||||
|
||||
## smooth-quant/dynamic-quant
|
||||
we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will undergo a rowwise dynamic quantization like below. Note that smooth-quant requires as input a `(1*N)` size per-channel scale (fp32 in our example, though this is customizable), which is element-wise multiplied with each row of the tensor before computing the rowwise dynamic quant. If `-fquant=2` is set, there is no input per-channel scale stage — only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py to support on-demand codegen)
|
||||

|
||||
|
||||
```
|
||||
# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
|
||||
# [m, 1]
|
||||
per_token_amax, _ = torch.max(
|
||||
input=torch.abs(hidden_states),
|
||||
dim=-1,
|
||||
keepdim=True
|
||||
)
|
||||
per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
|
||||
|
||||
# quant hidden_states
|
||||
hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
|
||||
|
||||
return hidden_states, per_token_scale
|
||||
# hidden_states now is int8 and will feed to next layer as input
|
||||
# per_token_scale will be used as dequant factor later layer
|
||||
```
|
||||
|
||||
## build
|
||||
```
|
||||
@@ -15,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-n n dimension (default:4096)
|
||||
-stride stride per row, if -1 then equal to n (default:-1)
|
||||
-e epsilon (default:1e-5)
|
||||
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision (default:fp16)
|
||||
-prec_o output precision, set auto will be the same as input (default:auto)
|
||||
-prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
|
||||
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
|
||||
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
|
||||
```
|
||||
|
||||
## limitations
|
||||
Note that `fquant=2`, `fadd=2`, and `prec_sx/prec_sy` other than `fp32` are not generated by default, though our kernel template supports them (TBD: add some flag in generate.py to generate those instances on demand). Besides, the N>8192 case will by default use the two-pass pipeline, where `-fquant=1/2` are not supported yet.
|
||||
|
||||
```
|
||||
# some case
|
||||
# standard fp16 layernorm 2d, m=10. n=1024
|
||||
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024
|
||||
|
||||
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8
|
||||
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1
|
||||
|
||||
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8
|
||||
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
|
||||
|
||||
```
|
||||
670
example/ck_tile/02_layernorm2d/generate.py
Normal file
670
example/ck_tile/02_layernorm2d/generate.py
Normal file
@@ -0,0 +1,670 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
def get_if_str(idx, total, lase_else = True):
    """Return the C/C++ dispatch keyword to emit at position `idx` of `total`.

    idx == 0          -> 'if'
    middle positions  -> 'else if'
    last position     -> 'else' when `lase_else` is True, otherwise 'else if'

    NOTE(review): `lase_else` looks like a typo for `last_else`; kept because
    it is part of the callable interface.
    """
    if idx == 0:
        return 'if'
    if idx < total - 1:
        return 'else if'
    # last entry: emit a catch-all 'else' only when requested
    return 'else' if lase_else else 'else if'
|
||||
|
||||
# Suffixes appended to generated instance file names, indexed by fused-add mode.
FUSED_ADD_ENUM_STR_MAP = [
    'no',    # no fused add
    'pras',  # pre-norm (per original annotation)
    'pra',   # post-norm (per original annotation)
]

# Suffixes indexed by fused-sweep (quantization) mode.
FUSED_FUSED_SWEEP_STR_MAP = [
    'no',
    'dquant',
]

# Short precision tag -> ck_tile C++ type name used in emitted source.
DATA_TYPE_MAP = {
    'fp32': 'float',
    'fp16': 'ck_tile::fp16_t',
    'bf16': 'ck_tile::bf16_t',
    'int8': 'ck_tile::int8_t',
}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
    """Render a truthy/falsy value as a C++ boolean literal string."""
    return 'true' if b_ else 'false'
|
||||
|
||||
class layernorm_fwd_codegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
struct layernorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using XScaleDataType = ck_tile::remove_cvref_t<XScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename XScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
XScaleDataType_,
|
||||
YScaleDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveMeanInvStd_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
"""
|
||||
API_COMMON_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = layernorm2d_fwd_args;
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const S& s, A a)
|
||||
{{
|
||||
using XDataType = typename Traits_::XDataType;
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using XScaleDataType = typename Traits_::XScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveMeanInvStd,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::MeanDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::InvStdDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::XScaleDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, XScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
PipelineTraits>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
||||
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
||||
{F_per_n_case}
|
||||
}}
|
||||
"""
|
||||
API_PER_N_CASE=""" {F_if} {F_N_COND} {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
API_INNER_CASE=""" {F_if} {F_VEC_COND}
|
||||
r={F_instance_func}(s, a);
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
class k_fuesd_add_enum(IntEnum):
|
||||
F_NO_ADD = 0
|
||||
F_PRE_ADD = 1
|
||||
F_PRE_ADD_STORE_RESIDUAL = 2
|
||||
|
||||
class k_fused_sweep_enum(IntEnum):
|
||||
F_NO_SWEEP = 0
|
||||
F_RENORM = 1
|
||||
F_DYNAMIC_QUANT = 2
|
||||
|
||||
@dataclass
|
||||
class k_traits:
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd : bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum
|
||||
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
|
||||
|
||||
@dataclass
|
||||
class k_shape:
|
||||
F_BlockTile : List[int]
|
||||
F_WarpPerBlock : List[int]
|
||||
F_WarpTile : List[int]
|
||||
F_Vector_ : List[int]
|
||||
@property
|
||||
def F_BlockSize(self) -> int:
|
||||
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
|
||||
|
||||
@dataclass
|
||||
class k_problem:
|
||||
F_XDataType : str
|
||||
F_GammaDataType : str
|
||||
F_BetaDataType : str
|
||||
F_ComputeDataType : str
|
||||
F_YDataType : str
|
||||
F_MeanDataType : str
|
||||
F_InvStdDataType : str
|
||||
F_BlockShape : str
|
||||
F_Traits : Any #k_traits
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_one_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_two_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue_problem:
|
||||
F_AccDataType : str
|
||||
F_ODataType : str
|
||||
F_kPadM : bool
|
||||
F_kPadN : bool
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue:
|
||||
F_problem : Any
|
||||
|
||||
@dataclass
|
||||
class k_kernel:
|
||||
F_pipeline : Any
|
||||
F_epilogue : Any
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_XDataType : str
|
||||
F_YDataType : str
|
||||
F_XScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
F_ThreadPerBlock_M : int
|
||||
F_ThreadPerBlock_N : int
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd_ : bool
|
||||
F_kTwoPass_ : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@property
|
||||
def call_name(self) -> str:
|
||||
return f'layernorm2d_fwd_<traits_<{self.trait_name}>>'
|
||||
|
||||
# string when define this kernel
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return f'template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
|
||||
|
||||
# this class hold kernel under same source file
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair : str
|
||||
F_N : str
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
prec_i, prec_o = self.F_DataTypePair.split(',')
|
||||
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
||||
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
|
||||
if self.F_add != 0:
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
return nnn
|
||||
|
||||
@property
|
||||
def instance_name(self) ->str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def content(self) ->str:
|
||||
instance_defs = ''
|
||||
for ins in self.instance_list:
|
||||
instance_defs += ins.def_name + '\n'
|
||||
return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
||||
|
||||
@property
|
||||
def name_api(self) -> str:
|
||||
return 'layernorm2d_fwd_api'
|
||||
|
||||
@property
|
||||
def name_common_header(self) -> str:
|
||||
return 'layernorm2d_fwd_api_common'
|
||||
|
||||
@property
|
||||
def content_api(self) -> str:
|
||||
# 1 sort based on dtype
|
||||
t_dtype_dict = dict()
|
||||
blobs = self.get_blobs()
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
||||
|
||||
d_str = ''
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
n_str = ''
|
||||
for i_n, n_ in enumerate(blob_per_t):
|
||||
blob_per_n = blob_per_t[n_]
|
||||
inner_str = ""
|
||||
for i_b, b_ in enumerate(blob_per_n):
|
||||
# generate single kernel instance file
|
||||
#vec_str = ""
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
||||
# _if = 'if' if i_ins == 0 else 'else if'
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
|
||||
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
|
||||
prec_i, prec_o = dtype_.split(',')
|
||||
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
|
||||
|
||||
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
|
||||
return api_base
|
||||
|
||||
@property
|
||||
def content_common_header(self) -> str:
|
||||
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
||||
|
||||
def get_blobs(self):
|
||||
h_traits = layernorm_fwd_codegen.h_traits
|
||||
h_instance = layernorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
||||
|
||||
# rm rn tm tn vn pd mv 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_x, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
||||
continue # skip non dynamic quant case
|
||||
if fused_quant == 1 and hs_key == 'big':
|
||||
continue
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_XScaleDataType = scale_y
|
||||
h_.F_YScaleDataType = scale_x
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
|
||||
blobs = self.get_blobs()
|
||||
with list_p.open('a') as list_f:
|
||||
# api related file
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# kernel instance file
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
def gen_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
blobs = self.get_blobs()
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
def list_blobs(args):
    """CLI entry for --list_blobs: write the blob manifest for every requested API.

    Only the 'fwd' API currently has a generator; other names are ignored.
    """
    for api in args.api.split(','):
        if api == 'fwd':
            layernorm_fwd_codegen(args.working_path, args.filter).list_blobs()
|
||||
|
||||
|
||||
def gen_blobs(args):
    """--gen_blobs mode entry: emit all generated sources for every
    requested API.

    ``args.api`` is a comma-separated list.  Each entry may carry a bracket
    suffix (the parser default is ``fwd[all]``); previously the suffix made
    the default compare unequal to ``'fwd'`` and silently generate nothing,
    so it is stripped before comparison.
    """
    for api in args.api.split(','):
        # tolerate a trait suffix like "fwd[all]"
        # NOTE(review): presumably the bracket payload is a trait filter --
        # TODO confirm whether it should be forwarded instead of dropped
        api_name = api.split('[', 1)[0].strip()
        if api_name == 'fwd':
            layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK layernorm kernel",
    )
    parser.add_argument(
        "-a",
        "--api",
        default='fwd[all]',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )

    # the directory for list_blobs/gen_blobs to write files into
    parser.add_argument(
        "-w",
        "--working_path",
        default="./",
        required=False,
        help="the path where all the blobs are going to be generated"
    )

    # this script has 2 modes:
    # 1) list_blobs mode: generate a txt file naming every file that would be
    #    generated.  Useful for build systems like cmake to construct source
    #    code dependencies by reading that file's content.
    # 2) gen_blobs mode: generate the actual kernel instances and api.  A
    #    framework (e.g. FA) that embeds this codegen only needs this mode.
    parser.add_argument(
        "-l",
        "--list_blobs",
        action='store_true',
        help="list all the kernels to a file"
    )

    parser.add_argument(
        "-g",
        "--gen_blobs",
        action='store_true',
        help="generate all kernels into different files"
    )

    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-t",
        "--traits",
        default="all",
        required=False,
        help="enable/disable some feature. default generate all"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        required=False,
        help="codegen receipt."
    )

    args = parser.parse_args()

    # exactly one of the two modes must be chosen; parser.error() prints
    # usage and exits with a non-zero status, so build-system wrappers
    # (cmake execute_process RESULT check) can detect the misuse.  The
    # previous bare sys.exit() exited with status 0 and hid the error.
    if args.gen_blobs == args.list_blobs:
        parser.error('gen_blobs/list_blobs must specify only one option')

    # create the output directory, including missing parents; a plain
    # mkdir() raised FileNotFoundError for nested non-existent paths
    Path(args.working_path).mkdir(parents=True, exist_ok=True)

    if args.list_blobs:
        list_blobs(args)
    else:
        gen_blobs(args)
|
||||
@@ -1,155 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
|
||||
// Shorthand that bundles one kernel configuration into a single type:
// the element type, the per-thread repeat counts and thread-block shape
// (M x N), the vector access width along N, and the boolean switches
// (pad N / save mean+invstd / two-pass pipeline) forwarded to
// layernorm2d_fwd_traits_.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,          // each thread repeat along M
          ck_tile::index_t Repeat_N_,          // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_,  // num threads along M
          ck_tile::index_t ThreadPerBlock_N_,  // num threads along N
          ck_tile::index_t Vector_N_,          // vector size along N
          bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
                                       Repeat_M_,
                                       Repeat_N_,
                                       ThreadPerBlock_M_,
                                       ThreadPerBlock_N_,
                                       Vector_N_,
                                       kPadN_,
                                       kSaveMeanInvStd_,
                                       kTwoPass_>;
|
||||
|
||||
// Problem-size dispatcher for the 16-bit (fp16/bf16) layernorm2d forward
// kernel.  For each bucket of row length a.n it selects a tile shape and
// the widest vector access along N (vn) whose divisibility check a.n % vn
// passes; rows longer than 4096 switch to the two-pass pipeline (last
// trait flag set to true).  Returns the measured kernel time, or -1 if no
// branch ran (cannot happen for the bucket layout below).
template <typename data_type>
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
                           layernorm2d_fwd_args a,
                           const ck_tile::stream_config& s)
{
#if 1
    // sentinel: "no instance executed"
    float r = -1;
    // clang-format off
    // trait_ argument order:                  rm  rn  tm   tn  vn    pd     mv     2p
    if(a.n <= 64) {
        r = layernorm2d_fwd_<trait_<data_type, 1,  1,  4,  64,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 128) {
        if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  1,  4,  64,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  4,  64,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 256) {
        if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  1,  4,  64,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  4,  64,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  4,  64,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 512) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  1,  4,  64,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  4,  64,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  4,  64,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  8,  4,  64,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 768) {
        // 768 is not a power of two, hence the x3 repeat factors
        if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  4,  64,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  6,  4,  64,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1, 12,  4,  64,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 1024) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  1,  2, 128,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  2, 128,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  2, 128,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 256,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 1536) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  4,  64,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  2, 128,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  1, 256,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  6,  1, 256,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 2048) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  1,  1, 256,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  1, 256,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 256,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  8,  1, 256,  1,  true, false, false>>(s, a);
    }
    else if(a.n <= 3072) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  1, 128,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  1, 256,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  6,  1, 256,  2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  3,  1, 1024, 1,  true, false, false>>(s, a);
    }
    else if(a.n <= 4096) {
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  1, 256,  8,  true, false, false>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 256,  4,  true, false, false>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  1, 1024, 2,  true, false, false>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 1024, 1,  true, false, false>>(s, a);
    }
    else if(a.n > 4096) {
        // beyond one tile's worth of columns: use the two-pass pipeline
        if (a.n % 8 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  1, 256,  8,  true, false, true>>(s, a);
        else if (a.n % 4 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 256,  4,  true, false, true>>(s, a);
        else if (a.n % 2 == 0)
            r = layernorm2d_fwd_<trait_<data_type, 1,  2,  1, 1024, 2,  true, false, true>>(s, a);
        else
            r = layernorm2d_fwd_<trait_<data_type, 1,  4,  1, 1024, 1,  true, false, true>>(s, a);
    }
    return r;
#else
    // debug path: single fixed configuration
    return layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif
    // clang-format on
}
|
||||
|
||||
// Public entry point: dispatch on the requested element type string.
// fp16 and bf16 share the 16-bit implementation above; any other type has
// no generated instance and raises.
//
// Fix: the previous body kept `float r = -1;`, a `if(r < 0) throw` guard
// and a trailing `return r;` -- r was never reassigned, so the return was
// unreachable and the sentinel only obscured the control flow.  The
// behavior (return on match, throw otherwise) is unchanged.
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    if(t.data_type.compare("fp16") == 0)
    {
        return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
    }
    else if(t.data_type.compare("bf16") == 0)
    {
        return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
    }
    // unsupported data type: message kept byte-identical to the original
    throw std::runtime_error("Without supported instances!");
}
|
||||
@@ -1,22 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
#if 0
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,22 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
#if 0
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,14 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -1,13 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,12 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd mv 2p
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
|
||||
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -1,67 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
// Short aliases used by the explicit instantiations in the generated
// per-instance translation units.
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;

// Same trait bundle as in the API translation unit: element type, tile /
// thread-block shape, N-vector width, and the pad / save-stats / two-pass
// switches, collapsed into one type parameter.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,          // each thread repeat along M
          ck_tile::index_t Repeat_N_,          // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_,  // num threads along M
          ck_tile::index_t ThreadPerBlock_N_,  // num threads along N
          ck_tile::index_t Vector_N_,          // vector size along N
          bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_,
                                       Repeat_M_,
                                       Repeat_N_,
                                       ThreadPerBlock_M_,
                                       ThreadPerBlock_N_,
                                       Vector_N_,
                                       kPadN_,
                                       kSaveMeanInvStd_,
                                       kTwoPass_>;
|
||||
|
||||
// Instantiate and launch one concrete layernorm2d forward kernel described
// by Traits_.  Returns whatever ck_tile::launch_kernel reports (elapsed
// time under the given stream_config).
template <typename Traits_>
float layernorm2d_fwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    // Assemble the pipeline problem: the per-dtype type bundle from
    // LayerNormTypeConfig plus the tile shape and switches from Traits_.
    using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
        typename LayerNormTypeConfig<DataType>::XDataType,
        typename LayerNormTypeConfig<DataType>::GammaDataType,
        typename LayerNormTypeConfig<DataType>::BetaDataType,
        typename LayerNormTypeConfig<DataType>::ComputeDataType,
        typename LayerNormTypeConfig<DataType>::YDataType,
        typename LayerNormTypeConfig<DataType>::MeanDataType,
        typename LayerNormTypeConfig<DataType>::InvStdDataType,
        typename Traits_::Shape,
        Traits_::kPadN,
        Traits_::kSaveMeanInvStd,
        Traits_::kTwoPass>;

    // kTwoPass selects the two-pass variant at compile time.
    using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
    using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
    using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;

    using Kernel = ck_tile::Layernorm2dFwd<Pipeline>;

    // grid depends on the runtime problem size; block shape is static
    const dim3 grids                       = Kernel::GridSize(a);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
@@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
|
||||
.insert("prec_sx",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1")
|
||||
.insert("prec_sy",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
@@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType, bool SaveMeanVar>
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename XScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
@@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sx = arg_parser.get_str("prec_sx");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sx == "auto")
|
||||
{
|
||||
prec_sx = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
if(fused_quant == 1 && prec_o != "int8")
|
||||
{
|
||||
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
assert(stride >= n);
|
||||
|
||||
using TypeConfig = LayerNormTypeConfig<DataType>;
|
||||
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using BetaDataType = typename TypeConfig::BetaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using BetaDataType = typename TypeConfig::BetaDataType;
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
using MeanDataType =
|
||||
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
|
||||
@@ -73,36 +112,72 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
|
||||
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
x_scale_buf.ToDevice(x_scale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_o)
|
||||
{
|
||||
base_str += "|" + prec_o;
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
base_str += std::string("(") + prec_sy + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{data_type, SaveMeanVar};
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
|
||||
y_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_mean, unsupported yet
|
||||
nullptr, // p_invStd, unsupported yet
|
||||
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
@@ -111,6 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float ave_time = layernorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
|
||||
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
|
||||
|
||||
@@ -122,6 +203,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
// TODO we accumulate directly to x_host for simplcity here...
|
||||
|
||||
std::transform(x_host.mData.cbegin(),
|
||||
x_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
x_host.mData.begin(),
|
||||
std::plus<XDataType>{});
|
||||
}
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
@@ -131,13 +223,80 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
InvStdDataType>(
|
||||
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
||||
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
|
||||
int N_ = acc_.mDesc.get_lengths()[1];
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
// input smooth outlier
|
||||
acc_(m_, n_) =
|
||||
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
|
||||
}
|
||||
}
|
||||
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
const auto a = ck_tile::abs(acc_(m_, n_));
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(x_host,
|
||||
gamma_host,
|
||||
beta_host,
|
||||
y_host_ref,
|
||||
mean_host_ref,
|
||||
invStd_host_ref,
|
||||
epsilon,
|
||||
dquant_functor);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(
|
||||
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>();
|
||||
ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(sy_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<InDataType>();
|
||||
|
||||
if(stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(
|
||||
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -153,8 +312,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> sy_host_dev_row(
|
||||
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YResidualDataType> sy_host_ref_row(
|
||||
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
|
||||
pass &= ck_tile::check_err(sy_host_dev_row,
|
||||
sy_host_ref_row,
|
||||
std::string("ADD[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
y_scale_buf.FromDevice(y_scale_host_dev.data());
|
||||
pass &= ck_tile::check_err(y_scale_host_dev,
|
||||
y_scale_host_ref,
|
||||
std::string("SCALE Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
@@ -168,23 +349,56 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
int save_mv = arg_parser.get_int("save_mv");
|
||||
if(data_type == "fp16" && save_mv)
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sx = arg_parser.get_str("prec_sx");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
|
||||
prec_o = prec_i;
|
||||
}
|
||||
else if(data_type == "fp16" && !save_mv)
|
||||
if(prec_sx == "auto")
|
||||
{
|
||||
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
|
||||
prec_sx = "fp32";
|
||||
}
|
||||
else if(data_type == "bf16" && save_mv)
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
else if(data_type == "bf16" && !save_mv)
|
||||
int save_mv = arg_parser.get_int("save_mv");
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -8,31 +8,35 @@
|
||||
#include "ck_tile/ops/layernorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename DataType>
|
||||
template <typename InType, typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig;
|
||||
|
||||
template <>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t>
|
||||
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t, OutType, XScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using XScaleDataType = XScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t>
|
||||
template <typename OutType, typename XScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, XScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
using InvStdDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using XScaleDataType = XScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
@@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename DataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kTwoPass_>
|
||||
struct layernorm2d_fwd_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Layernorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct layernorm2d_fwd_traits
|
||||
{
|
||||
std::string data_type;
|
||||
bool save_mean_var;
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sx; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_mean_var; //
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
BIN
example/ck_tile/02_layernorm2d/misc/dquant.png
Normal file
BIN
example/ck_tile/02_layernorm2d/misc/dquant.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 36 KiB |
BIN
example/ck_tile/02_layernorm2d/misc/pnorm.png
Normal file
BIN
example/ck_tile/02_layernorm2d/misc/pnorm.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 31 KiB |
@@ -2,37 +2,37 @@
|
||||
# run from top of ck folder
|
||||
EXE=build/bin/tile_example_layernorm2d_fwd
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000
|
||||
@@ -2,30 +2,34 @@
|
||||
# call from top of CK folder
|
||||
EXE=./build/bin/tile_example_layernorm2d_fwd
|
||||
|
||||
for fquant in "" "-fquant=1 -prec_o=int8"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -m=99 -n=13
|
||||
$EXE -prec=$pr_i -m=17 -n=16
|
||||
$EXE -prec=$pr_i -m=1 -n=100
|
||||
$EXE -prec=$pr_i -m=4 -n=128
|
||||
$EXE -prec=$pr_i -m=80 -n=127
|
||||
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
|
||||
$EXE -prec=$pr_i -m=7 -n=599
|
||||
$EXE -prec=$pr_i -m=19 -n=512
|
||||
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
|
||||
$EXE -prec=$pr_i -m=11 -n=510
|
||||
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
|
||||
$EXE -prec=$pr_i -m=91 -n=636
|
||||
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
|
||||
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
|
||||
$EXE -prec=$pr_i -m=31 -n=1024
|
||||
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec=$pr_i -m=8 -n=1501
|
||||
$EXE -prec=$pr_i -m=3 -n=1826
|
||||
$EXE -prec=$pr_i -m=5 -n=2040
|
||||
$EXE -prec=$pr_i -m=7 -n=2734
|
||||
$EXE -prec=$pr_i -m=1 -n=3182
|
||||
$EXE -prec=$pr_i -m=9 -n=4096
|
||||
$EXE -prec=$pr_i -m=3 -n=8192
|
||||
$EXE -prec=$pr_i -m=1 -n=10547
|
||||
$EXE -prec=$pr_i -m=3 -n=17134
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user