diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt
index feae5f791d..1bf74bc055 100644
--- a/example/ck_tile/02_layernorm2d/CMakeLists.txt
+++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt
@@ -1,11 +1,34 @@
+set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
+set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
+  set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
+endif()
+
+# generate a list of kernels, but do not actually emit files at config stage
+execute_process(
+  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
+  --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
+  RESULT_VARIABLE ret
+)
+if(ret AND NOT ret EQUAL 0)
+  message( FATAL_ERROR "Failed to generate kernels via Python. ${ret}")
+endif()
+
+file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
+
+add_custom_command(
+  OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
+  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
+  --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
+)
+
 set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
-# not using add_example_executable() to add this target, since we don't want this to have
-# to be included in "make all/install/check"
+
 message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
-file(GLOB INSTANCE_SRCS instances/*.cpp)
 add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
 target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
-target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
+target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
 
 set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md
index 405325a2a1..14c6fc0d67 100644
--- a/example/ck_tile/02_layernorm2d/README.md
+++ b/example/ck_tile/02_layernorm2d/README.md
@@ -1,6 +1,42 @@
 # Layernorm2D forward
-This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation.
+This folder contains an example of Layernorm2D forward using the `ck_tile` tile-programming implementation.
+
+# Implementation and feature support
+
+## Welford online algorithm
+We use the Welford online algorithm to update `mean`/`variance` block by block. For the `N <= 4096` case we can compute `mean`/`var`/`normalization` within one loop, which we call `one-pass`. For larger `N` it is hard to keep `mean`/`var` in registers/LDS while computing the `normalization`, so we need to load the input twice: the first time to compute `mean`/`var` block by block, and a second time to compute the `normalization`. We call this `two-pass` (see the sketch below).
+
+## mean/variance save
+In the training case the mean/variance need to be stored out (TBD, not supported yet).
+
+## prenorm/postnorm
+
+![](misc/pnorm.png)
+
+Since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example supports this feature through kernel fusion. Note that `prenorm`/`postnorm` always needs an elementwise add of a `shortcut` before the actual layernorm computation, optionally storing the result out to global memory. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without the store (not generated by default).
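+
+As a reference for the two sections above, below is a minimal NumPy sketch of the per-row semantics: an optional elementwise pre-add of the `shortcut` (the `-fadd` path), followed by a block-by-block Welford update of `mean`/`variance`. The helper names (`welford_merge`, `layernorm_row`) are illustrative only and do not mirror the kernel's internal API.
+
+```
+# hypothetical reference implementation, not the kernel API
+import numpy as np
+
+def welford_merge(na, mean_a, m2_a, nb, mean_b, m2_b):
+    # merge two partial (count, mean, M2) statistics (Chan et al. update)
+    n = na + nb
+    delta = mean_b - mean_a
+    mean = mean_a + delta * nb / n
+    m2 = m2_a + m2_b + delta * delta * na * nb / n
+    return n, mean, m2
+
+def layernorm_row(x, gamma, beta, shortcut=None, eps=1e-5, block=256):
+    if shortcut is not None:           # -fadd: elementwise pre-add of the shortcut
+        x = x + shortcut               # (-fadd=1 would also store this x out)
+    n, mean, m2 = 0, 0.0, 0.0
+    for i in range(0, x.size, block):  # accumulate mean/var block by block
+        blk = x[i:i + block]
+        n, mean, m2 = welford_merge(n, mean, m2,
+                                    blk.size, blk.mean(), blk.var() * blk.size)
+    inv_std = 1.0 / np.sqrt(m2 / n + eps)
+    return (x - mean) * inv_std * gamma + beta
+
+# a one-pass kernel fuses the final normalization into the same sweep when the
+# whole row fits on-chip; the two-pass variant re-reads x to apply it
+```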
+
+## smooth-quant/dynamic-quant
+We support smooth/dynamic quantization for `int8` output by setting `-fquant=1` and `-prec_o=int8`. In this case the output goes through a row-wise dynamic quantization like below. Note that smooth-quant requires a `(1*N)` per-channel scale as input (fp32 in our example, though this is customizable); the tensor is element-wise multiplied by this scale for each row before the row-wise dynamic quant is computed. Setting `-fquant=2` skips the input per-channel scale stage and performs only the dynamic quant. This case is supported by our kernel but not generated by default (TBD: add a filter in generate.py to support on-demand codegen).
+![](misc/dquant.png)
+
+```
+# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
+# [m, 1]
+per_token_amax, _ = torch.max(
+    input=torch.abs(hidden_states),
+    dim=-1,
+    keepdim=True
+)
+per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
+
+# quant hidden_states
+hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
+
+return hidden_states, per_token_scale
+# hidden_states is now int8 and will feed the next layer as input
+# per_token_scale will be used as the dequant factor in a later layer
+```
 
 ## build
 ```
@@ -15,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
 ```
 args:
          -m    m dimension (default:3328)
-         -n    m dimension (default:4096)
+         -n    n dimension (default:4096)
+    -stride    stride per row, if -1 then equal to n (default:-1)
          -e    epsilon (default:1e-5)
+   -save_mv    save mean/variance(invstd) or not. set to 1 in training case (default:0)
          -v    cpu validation or not (default:1)
-      -prec    precision (default:fp16)
+     -kname    print kernel name or not (default:1)
+    -prec_i    input precision (default:fp16)
+    -prec_o    output precision, set auto will be the same as input (default:auto)
+   -prec_sx    input smooth-quant scale type, set auto will use fp32. used when fquant=1 (default:auto)
+   -prec_sy    output quant scale type, set auto will use fp32. used when fquant=1 or 2 (default:auto)
+      -fadd    fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
+    -fquant    fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
+    -warmup    cold iter (default:5)
+    -repeat    hot iter (default:20)
+```
+
+## limitations
+Note that `fquant=2`, `fadd=2`, and `prec_sx`/`prec_sy` other than `fp32` are not generated by default, though our kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N > 8192` case will by default use the two-pass pipeline, where `-fquant=1/2` is not supported yet.
+
+```
+# some example cases
+# standard fp16 layernorm 2d, m=10, n=1024
+./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024
+
+# standard fp16 layernorm 2d, m=10, n=1024, fused-smooth-quant, output in int8
+./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1
+
+# standard fp16 layernorm 2d, m=10, n=1024, fused-smooth-quant+fused-add-store, output in int8
+./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
+
+```
\ No newline at end of file
diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py
new file mode 100644
index 0000000000..300f6c05e1
--- /dev/null
+++ b/example/ck_tile/02_layernorm2d/generate.py
@@ -0,0 +1,670 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
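+
+# Typical invocations (a hedged example mirroring how the CMakeLists.txt in
+# this folder drives the script; the flags are defined in the argparse section
+# at the bottom of this file):
+#   python3 generate.py --api fwd --working_path <build_dir> --list_blobs   # write layernorm2d_fwd_blobs.txt
+#   python3 generate.py --api fwd --working_path <build_dir> --gen_blobs    # write the actual .cpp/.hpp blobs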
+# generate kernel instances to speed up compilation
+
+import argparse
+from enum import IntEnum
+from pathlib import Path
+import sys
+from typing import List, Optional, Any
+import functools
+import itertools
+import copy
+from dataclasses import dataclass
+
+def get_if_str(idx, total, last_else = True):
+    if idx == 0:
+        return 'if'
+    elif idx < total - 1:
+        return 'else if'
+    else:
+        if last_else:
+            return 'else'
+        else:
+            return 'else if'
+
+FUSED_ADD_ENUM_STR_MAP = [
+    'no',
+    'pras',   # pre-add + store
+    'pra' ]   # pre-add only
+
+FUSED_FUSED_SWEEP_STR_MAP = [
+    'no',
+    'dquant' ]
+
+DATA_TYPE_MAP = {'fp32' : 'float',
+                 'fp16' : 'ck_tile::fp16_t',
+                 'bf16' : 'ck_tile::bf16_t',
+                 'int8' : 'ck_tile::int8_t'}
+
+def BOOL_MAP(b_) -> str:
+    if b_:
+        return 'true'
+    else:
+        return 'false'
+
+class layernorm_fwd_codegen:
+    API_TRAITS_DEFINE = """
+// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
+template
+struct layernorm2d_fwd_traits_
+{
+    using XDataType = ck_tile::remove_cvref_t;
+    using YDataType = ck_tile::remove_cvref_t;
+    using XScaleDataType = ck_tile::remove_cvref_t;
+    using YScaleDataType = ck_tile::remove_cvref_t;
+
+    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
+    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
+    static constexpr ck_tile::index_t total_warps =
+        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
+
+    // num of warps along m
+    static constexpr ck_tile::index_t BlockWarps_M = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(warpSize % ThreadPerBlock_N_ == 0);
+            return total_warps * (warpSize / ThreadPerBlock_N_);
+        }
+        else
+        {
+            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
+            return total_warps / (ThreadPerBlock_N_ / warpSize);
+        }
+    }();
+
+    // num of warps along n
+    static constexpr ck_tile::index_t BlockWarps_N = []() {
+        if constexpr(is_warp_per_row)
+        {
+            static_assert(warpSize % ThreadPerBlock_N_ == 0);
+            return 1;
+        }
+        else
+        {
+            static_assert(ThreadPerBlock_N_ % warpSize == 0);
+            return ThreadPerBlock_N_ / warpSize;
+        }
+    }();
+
+    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
+    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
+
+    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
+    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
+
+    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
+    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
+
+    using BlockTile  = ck_tile::sequence;
+    using BlockWarps = ck_tile::sequence;
+    using WarpTile   = ck_tile::sequence;
+    using Vector     = ck_tile::sequence<1, Vector_N_>;
+
+    using Shape = ck_tile::Generic2dBlockShape;
+
+    static constexpr bool kPadN           = kPadN_;
+    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
+    static constexpr bool kTwoPass        = kTwoPass_;
+    static constexpr ck_tile::index_t kFusedAdd   = kFusedAdd_;
+    static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
+};
+
+template
+using traits_ = layernorm2d_fwd_traits_;
+"""
+    API_COMMON_HEADER = """
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#include +#include "layernorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = layernorm2d_fwd_args; + +{F_traits_define} + +template +float layernorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using XScaleDataType = typename Traits_::XScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + + using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename LayerNormTypeConfig::XScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Layernorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a); + +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE=""" {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE=""" {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_api_common.hpp" + +// clang-format off +// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep +{F_instance_def} +// clang-format on + +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum + F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_GammaDataType : str + F_BetaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_MeanDataType : str + F_InvStdDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_XScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveMeanInvStd_ : bool + F_kTwoPass_ : bool + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'layernorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float layernorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + instance_defs += ins.def_name + '\n' + return 
layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'layernorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'layernorm2d_fwd_api_common' + + @property + def content_api(self) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs() + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self): + h_traits = layernorm_fwd_codegen.h_traits + h_instance = layernorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + fused_add_list = [0, 1] + fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant + + # rm rn tm tn vn pd mv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), + 
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * 
hs[0].F_Vector_N + for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_x, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1: + continue # skip non dynamic quant case + if fused_quant == 1 and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_XScaleDataType = scale_y + h_.F_YScaleDataType = scale_x + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'layernorm2d_fwd_blobs.txt' + blobs = self.get_blobs() + with list_p.open('a') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self) -> None: + w_p = Path(self.working_path) + (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs() + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK layernorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script have 2 modes + # 1) list_blobs mode, will generate a txt file with all the files going to be generated. + # this is useful in build system like cmake to construct source code dependency, by + # reading the content out of this file + # 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework + # like FA, only need to use this mode + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", + default="all", + required=False, + help="enable/disable some feature. 
default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt." + ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp deleted file mode 100644 index f2f51de5d9..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp +++ /dev/null @@ -1,155 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include "layernorm2d_fwd.hpp" - -template -using trait_ = layernorm2d_fwd_traits_; - -template -float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, - layernorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ -#if 1 - float r = -1; - // clang-format off - // rm rn tm tn vn pd mv 2p - if(a.n <= 64) { - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 128) { - if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 256) { - if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 512) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 768) { - if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 1024) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 1536) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 2048) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 3072) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 4096) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n > 4096) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - return r; -#else - return layernorm2d_fwd_>(s, a); -#endif - // clang-format on -} - -float layernorm2d_fwd(layernorm2d_fwd_traits t, - layernorm2d_fwd_args a, - const 
ck_tile::stream_config& s) -{ - - float r = -1; - if(t.data_type.compare("fp16") == 0) - { - return layernorm2d_fwd_b16_(t, a, s); - } - else if(t.data_type.compare("bf16") == 0) - { - return layernorm2d_fwd_b16_(t, a, s); - } - if(r < 0) - throw std::runtime_error("Without supported instances!"); - - return r; -} diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp deleted file mode 100644 index 2a20d1e057..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -#if 0 -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -template float layernorm2d_fwd_>(const S&, A); -#endif - -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp deleted file mode 100644 index d043efc86c..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp deleted file mode 100644 index a6ffc8cd2f..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp deleted file mode 100644 index 80beeca67b..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp deleted file mode 100644 index b362a550a0..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp deleted file mode 100644 index 9c2d78999c..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index c0c75f878b..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp deleted file mode 100644 index 1bcd0f8a7e..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp deleted file mode 100644 index 6b25fce8c2..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp deleted file mode 100644 index c4400f0f24..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp deleted file mode 100644 index 7f0e4898cb..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -#if 0 -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -template float layernorm2d_fwd_>(const S&, A); -#endif - -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp deleted file mode 100644 index 8c3a42cc4f..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp deleted file mode 100644 index 04d8bc1533..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp deleted file mode 100644 index c325747494..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp deleted file mode 100644 index c71db57a6a..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp deleted file mode 100644 index f3ca0932ef..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index 242f1d2dd5..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp deleted file mode 100644 index e3bfa8e3a4..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp deleted file mode 100644 index 90d960cf09..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp deleted file mode 100644 index 0960a95c31..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp deleted file mode 100644 index 22895e8edd..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp +++ /dev/null @@ -1,67 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include "layernorm2d_fwd.hpp" -#include - -#pragma once - -using S = ck_tile::stream_config; -using A = layernorm2d_fwd_args; - -template -using trait_ = layernorm2d_fwd_traits_; - -template -float layernorm2d_fwd_(const S& s, A a) -{ - using DataType = typename Traits_::DataType; - - using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< - typename LayerNormTypeConfig::XDataType, - typename LayerNormTypeConfig::GammaDataType, - typename LayerNormTypeConfig::BetaDataType, - typename LayerNormTypeConfig::ComputeDataType, - typename LayerNormTypeConfig::YDataType, - typename LayerNormTypeConfig::MeanDataType, - typename LayerNormTypeConfig::InvStdDataType, - typename Traits_::Shape, - Traits_::kPadN, - Traits_::kSaveMeanInvStd, - Traits_::kTwoPass>; - - using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; - using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; - - using Kernel = ck_tile::Layernorm2dFwd; - - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; - - auto kargs = Kernel::MakeKargs(a); - if(s.log_level_ > 0) - std::cout << ", " << Kernel::GetName() << std::flush; - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 4f12d91032..43f4e8c724 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -1,5 +1,6 @@ #include "ck_tile/host.hpp" #include "layernorm2d_fwd.hpp" +#include #include // different threshold for different dtype @@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[]) .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") - .insert("prec", "fp16", "precision") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sx", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. 
used when fquant=1 or 2") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t stride = arg_parser.get_int("stride"); if(stride < 0) stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + float epsilon = arg_parser.get_float("e"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sx == "auto") + { + prec_sx = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + if(fused_quant == 1 && prec_o != "int8") + { + std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; + return false; + } assert(stride >= n); - using TypeConfig = LayerNormTypeConfig; + using TypeConfig = LayerNormTypeConfig; - using XDataType = typename TypeConfig::XDataType; - using YDataType = typename TypeConfig::YDataType; - using GammaDataType = typename TypeConfig::GammaDataType; - using BetaDataType = typename TypeConfig::BetaDataType; + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; using MeanDataType = std::conditional_t; @@ -73,36 +112,72 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); + ck_tile::HostTensor x_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor mean_host_ref({m}); ck_tile::HostTensor invStd_host_ref({m}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); + + ck_tile::HostTensor x_scale_host({n}); + ck_tile::HostTensor x_scale_host_dev({n}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem 
y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); beta_buf.ToDevice(beta_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + x_scale_buf.ToDevice(x_scale_host.data()); - std::cout << "[" << data_type << "]" + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); + + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; - layernorm2d_fwd_traits traits{data_type, SaveMeanVar}; + layernorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, gamma_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), - nullptr, - nullptr, + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_mean, unsupported yet + nullptr, // p_invStd, unsupported yet + epsilon, m, n, @@ -111,6 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = layernorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; @@ -122,6 +203,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + std::plus{}); + } ck_tile::reference_layernorm2d_fwd( x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = + acc_(m_, n_) * ck_tile::type_convert(x_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? 
a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + ComputeDataType y_scale = absmax / static_cast(127.0); + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_layernorm2d_fwd(x_host, + gamma_host, + beta_host, + y_host_ref, + mean_host_ref, + invStd_host_ref, + epsilon, + dquant_functor); + } + else + { + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + } + y_buf.FromDevice(y_host_dev.data()); - auto [rtol, atol] = get_elimit(); + ck_tile::HostTensor sy_host_dev({m, n}, {stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(sy_host_dev.data()); + } + + auto [rtol, atol] = get_elimit(); + if(stride == n) { pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + pass &= ck_tile::check_err( + sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol); + } } else { @@ -153,8 +312,30 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string("] Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + std::vector sy_host_dev_row( + sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n); + std::vector sy_host_ref_row( + x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); + pass &= ck_tile::check_err(sy_host_dev_row, + sy_host_ref_row, + std::string("ADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("SCALE Error: Incorrect results!"), + rtol, + atol); + } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } @@ -168,23 +349,56 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_mv = arg_parser.get_int("save_mv"); - if(data_type == "fp16" && save_mv) + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + + if(prec_o == "auto") { - return run(arg_parser) ? 0 : -2; + prec_o = prec_i; } - else if(data_type == "fp16" && !save_mv) + if(prec_sx == "auto") { - return run(arg_parser) ? 0 : -2; + prec_sx = "fp32"; } - else if(data_type == "bf16" && save_mv) + if(prec_sy == "auto") { - return run(arg_parser) ? 0 : -2; + prec_sy = "fp32"; } - else if(data_type == "bf16" && !save_mv) + int save_mv = arg_parser.get_int("save_mv"); + + // no dynamic quant case + if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 
0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index 861e4a0230..a0f2db0e8a 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -8,31 +8,35 @@ #include "ck_tile/ops/layernorm2d.hpp" #include -template +template struct LayerNormTypeConfig; -template <> -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; + using YDataType = OutType; using GammaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t; using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template <> -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { using XDataType = ck_tile::bf16_t; - using YDataType = ck_tile::bf16_t; + using YDataType = OutType; using GammaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t; using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs { }; -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct layernorm2d_fwd_traits_ -{ - using DataType = ck_tile::remove_cvref_t; - - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; - } - }(); - - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; - static constexpr ck_tile::index_t Repeat_N = Repeat_N_; - - static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; - static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Layernorm2dShape; - - static constexpr 
bool kPadN = kPadN_;
-    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
-    static constexpr bool kTwoPass        = kTwoPass_;
-};
-
-template
-float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
-
 // This is the public API, will be generated by script
 struct layernorm2d_fwd_traits
 {
-    std::string data_type;
-    bool save_mean_var;
+    std::string prec_i; // input precision
+    std::string prec_o; // output precision
+
+    // if fused_quant == 1, prec_sx/prec_sy must be set to proper strings, otherwise they can be
+    // arbitrary (the check is skipped); if fused_quant == 2, prec_sy must be set to a proper
+    // string, otherwise it can be arbitrary (the check is skipped)
+    std::string prec_sx; // x-scale, used for [1*N] input smooth quant
+    std::string prec_sy; // y-scale, used for [M*1] output scale for the next layer
+
+    bool save_mean_var;
+    int fused_add;   // 0:no-add, 1:pre-add-store, 2:pre-add
+    int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
 };

 float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
diff --git a/example/ck_tile/02_layernorm2d/misc/dquant.png b/example/ck_tile/02_layernorm2d/misc/dquant.png
new file mode 100644
index 0000000000..28b1a61a14
Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/dquant.png differ
diff --git a/example/ck_tile/02_layernorm2d/misc/pnorm.png b/example/ck_tile/02_layernorm2d/misc/pnorm.png
new file mode 100644
index 0000000000..65a27e8751
Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/pnorm.png differ
diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh
index bfb7f9ffe5..a34624536c 100755
--- a/example/ck_tile/02_layernorm2d/script/perf_test.sh
+++ b/example/ck_tile/02_layernorm2d/script/perf_test.sh
@@ -2,37 +2,37 @@
 # run from top of ck folder
 EXE=build/bin/tile_example_layernorm2d_fwd

-$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
-$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
+$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
+$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index dcd40fda40..d56406b6f2 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -2,30 +2,34 @@ # call from top of CK folder EXE=./build/bin/tile_example_layernorm2d_fwd +for fquant in "" "-fquant=1 -prec_o=int8"; do for pr_i in "fp16" "bf16" ; do -$EXE -prec=$pr_i -m=99 -n=13 -$EXE -prec=$pr_i -m=17 -n=16 -$EXE -prec=$pr_i -m=1 -n=100 -$EXE -prec=$pr_i -m=4 -n=128 -$EXE -prec=$pr_i -m=80 -n=127 -$EXE -prec=$pr_i -m=22 -n=255 -stride=256 -$EXE -prec=$pr_i -m=7 -n=599 -$EXE -prec=$pr_i -m=19 -n=512 -$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 -$EXE -prec=$pr_i -m=11 -n=510 -$EXE -prec=$pr_i -m=171 -n=676 -stride=818 -$EXE -prec=$pr_i -m=91 -n=636 -$EXE -prec=$pr_i -m=12 -n=768 -stride=800 -$EXE -prec=$pr_i -m=100 -n=766 -stride=812 -$EXE -prec=$pr_i -m=31 -n=1024 -$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 -$EXE -prec=$pr_i -m=8 -n=1501 -$EXE -prec=$pr_i -m=3 -n=1826 -$EXE -prec=$pr_i -m=5 -n=2040 -$EXE -prec=$pr_i -m=7 -n=2734 -$EXE -prec=$pr_i -m=1 -n=3182 -$EXE -prec=$pr_i -m=9 -n=4096 -$EXE -prec=$pr_i -m=3 -n=8192 -$EXE 
-prec=$pr_i -m=1 -n=10547 -$EXE -prec=$pr_i -m=3 -n=17134 +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done done diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2c423831e1..3b198502d0 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -25,6 +25,7 @@ #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/int8.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" diff --git a/include/ck_tile/core/numeric/int8.hpp b/include/ck_tile/core/numeric/int8.hpp new file mode 100644 index 0000000000..9ca3333c39 --- /dev/null +++ b/include/ck_tile/core/numeric/int8.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
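+// Quick usage sketch (illustrative only): in the fused dynamic-quant path this
+// type is the storage type for the quantized output. Per output row,
+//   float  y_scale = absmax / 127.f;                  // 127 == numeric<int8_t>::max()
+//   int8_t q       = float_to_int8(value / y_scale);  // |value| <= absmax, so q fits
+// float_to_int8 below is a plain static_cast (truncation toward zero, no
+// clamping), so callers are expected to pre-scale into the representable range.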
+
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/integral_constant.hpp"
+#include "ck_tile/core/numeric/math.hpp"
+#include "ck_tile/core/numeric/numeric.hpp"
+#include "ck_tile/core/utility/bit_cast.hpp"
+#include "ck_tile/core/utility/random.hpp"
+#include <stdint.h>
+#include <type_traits>
+
+#pragma once
+
+namespace ck_tile {
+
+// use int8_t directly for int8 arithmetic
+// here one can use ck_tile::int8_t to access the original int8_t
+using int8_t = int8_t;
+
+// limits
+template <typename T>
+struct numeric;
+
+template <>
+struct numeric<int8_t>
+{
+    // minimum finite value, or minimum positive normalized value for float
+    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
+
+    // minimum finite value
+    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
+
+    // maximum finite value
+    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
+
+    // difference between 1.0 and next value representable by float
+    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
+    {
+        return 1; // not used
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
+    {
+        return 1; // not used
+    }
+
+    // positive infinity value
+    CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
+    {
+        return 1; // not used
+    }
+
+    // quiet NaN
+    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
+    {
+        return 1; // not used
+    }
+
+    // signaling NaN
+    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
+    {
+        return 1; // not used
+    }
+
+    // smallest positive subnormal value
+    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
+    {
+        return 1; // not used
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
+};
+
+#if 0
+template <typename T>
+struct numeric_traits;
+
+template <>
+struct numeric_traits<fp16_t>
+{
+    static constexpr int exp            = 5;
+    static constexpr int mant           = 10;
+    static constexpr int bias           = 15;
+    static constexpr uint16_t nan_mask  = 0x7C00;
+    static constexpr uint16_t head_mask = 0xFC00;
+    static constexpr uint16_t mant_mask = 0x3FF;
+    static constexpr uint16_t exp_mask  = 0x1F;
+    static constexpr uint32_t Inf       = 0x7C00;
+    static constexpr uint32_t NegInf    = 0xFC00;
+    static constexpr uint32_t NaN       = 0x7C01;
+    static constexpr uint32_t Neg0      = 0x8000;
+    using bitwise_type = uint16_t;
+};
+#endif
+
+CK_TILE_HOST_DEVICE
+constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
+
+CK_TILE_HOST_DEVICE
+constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
+
+} // namespace ck_tile
diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp
index cb18cde70d..4011e08ce4 100644
--- a/include/ck_tile/core/numeric/type_convert.hpp
+++ b/include/ck_tile/core/numeric/type_convert.hpp
@@ -10,6 +10,7 @@
 #include "ck_tile/core/numeric/half.hpp"
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
+#include "ck_tile/core/numeric/int8.hpp"

 namespace ck_tile {

@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
 CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
 CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

+CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
+CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
+
 #undef CK_TILE_TYPE_CONVERT
 #endif
diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp
index 9707f2990a..de99be1965 100644
--- a/include/ck_tile/core/tensor/null_tile_window.hpp
+++ b/include/ck_tile/core/tensor/null_tile_window.hpp
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
     return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
 }

+template <typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window<WindowLengths>& t,
+                                               const StaticTileDistribution&)
+{
+    return t;
+}
+
 template <typename WindowLengths>
 CK_TILE_DEVICE void move_tile_window(null_tile_window<WindowLengths>&,
diff --git a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
index 837f52c399..62cd26b6ab 100644
--- a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
+++ b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
@@ -8,20 +8,44 @@
 namespace ck_tile {

+// Note: for simplicity, each functor only cares about a single M (one row)
+struct reference_layernorm2d_default_epilogue
+{
+    template <typename OutDataType, typename AccDataType>
+    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
+    {
+        const int N = acc.mDesc.get_lengths()[1];
+        for(int n = 0; n < N; ++n)
+        {
+            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
+        }
+    }
+
+    template <typename OutDataType, typename AccDataType>
+    auto operator()(int m, const HostTensor<AccDataType>& acc)
+    {
+        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
+        operator()(m, o, acc);
+        return o;
+    }
+};
+
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
           typename MeanDataType,
-          typename InvStdDataType>
+          typename InvStdDataType,
+          typename Epilogue = reference_layernorm2d_default_epilogue>
 void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                                const HostTensor<GammaDataType>& gamma_n,
                                const HostTensor<BetaDataType>& beta_n,
                                HostTensor<YDataType>& y_m_n,
                                HostTensor<MeanDataType>& mean_m,
                                HostTensor<InvStdDataType>& invStd_m,
-                               ComputeDataType epsilon)
+                               ComputeDataType epsilon,
+                               Epilogue epilogue_functor = {})
 {
     auto layernorm2d_fwd_func = [&](auto m) {
         const int N = x_m_n.mDesc.get_lengths()[1];
@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
         if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
             invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

+        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
         for(int n = 0; n < N; ++n)
         {
             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
             ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));

-            auto y = (x - mean) * divisor;
-            y      = y * gamma + beta;
+            auto a_ = (x - mean) * divisor;
+            a_      = a_ * gamma + beta;

-            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
+            acc(m, n) = a_;
         }
+
+        epilogue_functor(m, y_m_n, acc);
     };

     make_ParallelTensorFunctor(layernorm2d_fwd_func,
diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
index eb06fea2dd..fb8d7221b8 100644
--- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
@@ -9,4 +9,5 @@
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp
index 4363ea1f55..1510f18a30 100644
--- a/include/ck_tile/ops/common.hpp
+++ b/include/ck_tile/ops/common.hpp
@@ -3,4 +3,5 @@

 #pragma once

+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp
similarity index 96%
rename from
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp rename to include/ck_tile/ops/common/generic_2d_block_shape.hpp index e4b60331eb..64ad20c3be 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -1,11 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck_tile/core.hpp" - namespace ck_tile { + /* // clang-format off @@ -42,7 +41,7 @@ template typename Vector_, // contiguous pixels(vector size) along seq index_t BlockSize_ = warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> -struct Layernorm2dShape +struct Generic2dBlockShape { // block size static constexpr index_t Block_M = BlockTile_::at(number<0>{}); diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 62ba9dc0b3..cd1e43fb8c 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -4,4 +4,5 @@ #pragma once #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index a98f60b364..c24744bdbc 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -5,4 +5,6 @@ #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 5dc49c3b0e..7c5d5a6f31 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -9,23 +9,29 @@ namespace ck_tile { // this epilogue just store out a M*N matrix, row major -template +template struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; }; template struct Default2DEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseRawStore = Problem::UseRawStore; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } @@ -36,7 +42,7 @@ struct Default2DEpilogue { // TODO: this is ugly - if constexpr(kPadM || kPadN) + if constexpr(UseRawStore && (kPadM || kPadN)) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); buffer_store_fence(); diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp new file mode 100644 index 
0000000000..2e29604116
--- /dev/null
+++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
@@ -0,0 +1,140 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/reduce.hpp"
+
+namespace ck_tile {
+
+template <bool kPadM_, bool kPadN_, bool UseRawStore_, bool UseMax3_>
+struct DynamicQuantEpilogueTraits
+{
+    static constexpr bool kPadM       = kPadM_;
+    static constexpr bool kPadN       = kPadN_;
+    static constexpr bool UseRawStore = UseRawStore_;
+    static constexpr bool UseMax3     = UseMax3_;
+};
+
+// this epilogue stores out an M*N matrix, row major
+template <typename AccDataType_,
+          typename YScaleDataType_,
+          typename ODataType_,
+          typename BlockShape_,
+          typename Traits_>
+struct DynamicQuantEpilogueProblem
+{
+    using AccDataType    = remove_cvref_t<AccDataType_>;
+    using YScaleDataType = remove_cvref_t<YScaleDataType_>;
+    using ODataType      = remove_cvref_t<ODataType_>;
+    using BlockShape     = remove_cvref_t<BlockShape_>; // can consume a generic 2d shape
+    using Traits         = remove_cvref_t<Traits_>;
+};
+
+template <typename Problem_>
+struct DynamicQuantEpilogue
+{
+    using Problem        = remove_cvref_t<Problem_>;
+    using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
+    using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
+    using ODataType      = remove_cvref_t<typename Problem::ODataType>;
+    using BlockShape     = remove_cvref_t<typename Problem::BlockShape>;
+    static constexpr bool kPadM       = Problem::Traits::kPadM;
+    static constexpr bool kPadN       = Problem::Traits::kPadN;
+    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
+    static constexpr bool UseMax3     = Problem::Traits::UseMax3;
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, BlockShape>;
+        return BlockReduce2d<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, BlockShape>;
+        return BlockReduce2dSync<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, BlockShape>;
+        return BlockReduce2dCrossWarpSync<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+        return reduce_crosswarp_sync.GetSmemSize();
+    }
+
+    // TODO: this function assumes the store-out vector size is the same as the
+    // OAccTile last-dimension size. How do we fix this?
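+    // For reference, the math below is, per row m (scalar sketch; the real
+    // implementation is tiled and vectorized):
+    //   absmax     = max_n |acc(m, n)|                    // optionally via v_max3_f32
+    //   y_scale(m) = absmax / numeric<ODataType>::max()   // e.g. 127 for int8
+    //   o(m, n)    = convert<ODataType>(acc(m, n) / y_scale(m))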
+    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
+    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
+                                   YScaleWindow& y_scale_window,
+                                   const OAccTile& o_acc_tile,
+                                   void* smem)
+    {
+        auto reduce                = GetBlockReduce2d();
+        auto reduce_sync           = GetBlockReduce2dSync();
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+
+        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
+
+        auto row_absmax = [&]() {
+            constexpr auto y_size_per_row =
+                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
+                    number<1>{});
+            // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
+            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
+            {
+                // fast max3 implementation
+                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
+                    float rtn;
+                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                                 : "=v"(rtn)
+                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
+                    return rtn;
+                };
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
+            }
+            else
+            {
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
+            }
+        }();
+        reduce_sync(row_absmax, f_absmax);
+        reduce_crosswarp_sync(row_absmax, smem, f_absmax);
+
+        // here y_scale is AccDataType; it is converted to YScaleDataType below
+        auto y_scale = tile_elementwise_in(
+            [&](const auto& v_) {
+                return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
+            },
+            row_absmax);
+
+        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
+
+        auto o_acc_scaled_tile =
+            make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution());
+
+        sweep_tile(o_acc_tile, [&](auto idx) {
+            constexpr auto row_id  = make_tuple(idx[number<0>{}]);
+            o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
+        });
+
+        // TODO: this is ugly
+        if constexpr(UseRawStore && (kPadM || kPadN))
+        {
+            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+            buffer_store_fence();
+        }
+        else
+        {
+            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+        }
+    }
+};
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp
index 9389a5397f..e106264cef 100644
--- a/include/ck_tile/ops/fmha.hpp
+++ b/include/ck_tile/ops/fmha.hpp
@@ -43,4 +43,5 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp
index c3e028528b..ac74782a3a 100644
--- a/include/ck_tile/ops/gemm.hpp
+++ b/include/ck_tile/ops/gemm.hpp
@@ -39,4 +39,5 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp
index 57e83a7a51..2b02bcc5d2 100644
--- a/include/ck_tile/ops/image_to_column.hpp
+++ b/include/ck_tile/ops/image_to_column.hpp
@@ -6,4 +6,5 @@
 #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git
a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 2a403b0f49..711c5d8595 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -4,9 +4,10 @@ #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" -#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index cebe5131a7..9a2e06d05f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -5,19 +5,24 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" namespace ck_tile { // host side args struct Layernorm2dFwdHostArgs { - const void* p_x; - const void* p_gamma; - const void* p_beta; + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used + const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used + const void* p_gamma; // [1, n], gamma, prec same as input + const void* p_beta; // [1, n], beta, prec same as input - void* p_y; - void* p_mean; - void* p_invStd; + void* p_y; // [m, n], output, fp16/bf16 + void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used + void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used + void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used + void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used float epsilon; @@ -27,10 +32,11 @@ struct Layernorm2dFwdHostArgs }; // TODO: Extract some type to wrapper class -template +template struct Layernorm2dFwd { using Pipeline = remove_cvref_t; + using Epilogue = remove_cvref_t; using Problem = typename Pipeline::Problem; using XDataType = remove_cvref_t; @@ -40,18 +46,26 @@ struct Layernorm2dFwd using YDataType = remove_cvref_t; using MeanDataType = remove_cvref_t; using InvStdDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + + // for simplicity, shortcut input/output type is same as X + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; - static constexpr index_t Block_M = Problem::BlockShape::Block_M; - static constexpr index_t 
Block_N = Problem::BlockShape::Block_N; - static constexpr bool kPadM = false; // always no need to pad along M - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kTwoPass = Problem::kTwoPass; + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kTwoPass = Problem::Traits::kTwoPass; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; @@ -62,13 +76,18 @@ struct Layernorm2dFwd struct Kargs { - const void* p_x; - const void* p_gamma; - const void* p_beta; + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used + const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used + const void* p_gamma; // [1, n], gamma, prec same as input + const void* p_beta; // [1, n], beta, prec same as input - void* p_y; - void* p_mean; - void* p_invStd; + void* p_y; // [m, n], output, fp16/bf16 + void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used + void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used + + void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used + void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used float epsilon; @@ -81,9 +100,13 @@ struct Layernorm2dFwd CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) { return Kargs{hargs.p_x, + hargs.p_x_residual, + hargs.p_x_scale, hargs.p_gamma, hargs.p_beta, hargs.p_y, + hargs.p_y_residual, + hargs.p_y_scale, hargs.p_mean, hargs.p_invStd, hargs.epsilon, @@ -106,6 +129,7 @@ struct Layernorm2dFwd template <> struct t2s { static constexpr const char * name = "bf16"; }; template <> struct t2s { static constexpr const char * name = "fp8"; }; template <> struct t2s { static constexpr const char * name = "bf8"; }; + template <> struct t2s { static constexpr const char * name = "int8"; }; // clang-format on // in byte @@ -113,24 +137,41 @@ struct Layernorm2dFwd CK_TILE_HOST static std::string GetName() { +#define _SS_ std::string +#define _TS_ std::to_string // clang-format off using S_ = typename Problem::BlockShape; auto surfix = [&] () { std::string n; + if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName::name; + if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName::name; if (kPadN) n += "_pn"; if (kSaveMeanInvStd) n += "_mv"; - if (kTwoPass) n += "_2p"; + // if (kTwoPass) n += "_2p"; return n; }(); - #define _SS_ std::string - #define _TS_ std::to_string - return _SS_("layernorm2d_fwd_") + _SS_(t2s::name) + "_" + + auto prec_str = [&] () { + std::string base_str = _SS_(t2s::name); + if (!std::is_same_v) { + base_str += _SS_("_") + _SS_(t2s::name); + } + if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { + base_str += _SS_("_sx") + _SS_(t2s::name); + base_str += _SS_("_sy") + _SS_(t2s::name); + } + if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { + base_str += _SS_("_sy") + _SS_(t2s::name); + } + return base_str; + }(); + + 
return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _SS_(Pipeline::name) + surfix; - #undef _SS_ - #undef _TS_ // clang-format on +#undef _SS_ +#undef _TS_ } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -153,6 +194,31 @@ struct Layernorm2dFwd tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); + const auto x_residual_window = [&]() { + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x_residual), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel + // will check the max count dynamically + const auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + const auto gamma_window = [&]() { const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_gamma), @@ -194,6 +260,28 @@ struct Layernorm2dFwd tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); + auto y_residual_window = [&]() { + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y_residual), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + auto mean_window = [&]() { if constexpr(kSaveMean) { @@ -232,17 +320,60 @@ struct Layernorm2dFwd return make_null_tile_window(make_tuple(number{})); }(); + auto x_scale_window = [&]() { + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + const auto win_ = [&]() { + const auto tmp_0_ = make_naive_tensor_view_packed( + static_cast(kargs.p_x_scale), + make_tuple(kargs.n), + number{}); + + return pad_tensor_view(tmp_0_, + make_tuple(number{}), + sequence{}); // x_scale no need pad + }(); + return make_tile_window(win_, make_tuple(number{}), {0}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + auto y_scale_window = [&]() { + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT || + kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) + { + const auto win_ = [&]() { + const auto tmp_0_ = make_naive_tensor_view_packed( + static_cast(kargs.p_y_scale), + make_tuple(kargs.m), + number<1>{}); + + return pad_tensor_view( + tmp_0_, make_tuple(number{}), sequence{}); + }(); + return make_tile_window(win_, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + __shared__ char smem[GetSmemSize()]; Pipeline{}(x_window, + x_residual_window, gamma_window, beta_window, y_window, + y_residual_window, mean_window, inv_std_window, + x_scale_window, + y_scale_window, static_cast(kargs.epsilon), kargs.n, - smem); + smem, + Epilogue{}); } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp 
b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index c767a472a9..16a7c3b86d 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include #include @@ -24,20 +25,25 @@ struct Layernorm2dFwdPipelineOnePass using MeanDataType = ck_tile::remove_cvref_t; using InvStdDataType = ck_tile::remove_cvref_t; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr_op"; // block per row + return "bpr"; // block per row else - return "wpr_op"; // warp per row + return "wpr"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass } template + typename InvStdWindow, + typename XScaleWindow, + typename YScaleWindow, + typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, const GammaWindow& gamma_window_, const BetaWindow& beta_window_, - YWindow& y_window, + YWindow& y_window_, + const YResidualWindow& y_residual_window_, MeanWindow& mean_window, InvStdWindow& inv_std_window, + const XScaleWindow& x_scale_window_, + YScaleWindow& y_scale_window, ComputeDataType epsilon, ck_tile::index_t row_size, - void* smem) const + void* smem, + Epilogue) const { const auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); @@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); const auto beta_window = make_tile_window( beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + const auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); + const auto x_scale_window = make_tile_window( + x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + auto x_scale = load_tile(x_scale_window); - const auto x = load_tile(x_window); int cur_count = 0; int max_count = block_tile_welford_calculate_max_count(row_size); @@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); + if constexpr(kFusedAdd == 
Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + store_tile(y_residual_window, x); + } + // compute welford each-thread->cross-lane->cross-warp auto [mean, var] = block_welford(x, cur_count, max_count); block_welford_sync(mean, var, cur_count); @@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass store_tile(inv_std_window, cast_tile(inv_std)); // layernorm computation - auto y = make_static_distributed_tensor(x.get_tile_distribution()); - sweep_tile(y, [&, mean_ = mean](auto idx) { + auto ln = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass const auto beta_ = type_convert(beta[j_idx]); const auto x_ = type_convert(x[idx]); - auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; - y(idx) = type_convert(y_); + ln(idx) = ln_; }); - store_tile(y_window, y); + + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + // smooth-quant pre-scale, then run rowwise-quant + sweep_tile(ln, [&](auto idx) { + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + const auto xs_ = type_convert(x_scale[j_idx]); + ln(idx) = ln(idx) * xs_; + }); + } + + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || + kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + Epilogue{}(y_window_, y_scale_window, ln, smem); + } + else + Epilogue{}(y_window_, ln); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp index 8e9f8e81e4..7ec830add1 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp @@ -14,10 +14,10 @@ template + typename Traits_> struct Layernorm2dFwdPipelineProblem { using XDataType = remove_cvref_t; @@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem using YDataType = remove_cvref_t; using MeanDataType = remove_cvref_t; using InvStdDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; - static constexpr bool kTwoPass = kTwoPass_; + using Traits = remove_cvref_t; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index e35d02e707..ec10efbc69 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass using MeanDataType = ck_tile::remove_cvref_t; using InvStdDataType = 
ck_tile::remove_cvref_t; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr_tp"; // block per row + return "bpr_2p"; // block per row else - return "wpr_tp"; // warp per row + return "wpr_2p"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass } template + typename InvStdWindow, + typename XScaleWindow, + typename YScaleWindow, + typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, const GammaWindow& gamma_window_, const BetaWindow& beta_window_, YWindow& y_window, + const YResidualWindow& y_residual_window_, MeanWindow& mean_window, InvStdWindow& inv_std_window, + const XScaleWindow& /*x_scale_window*/, + YScaleWindow& /*y_scale_window*/, ComputeDataType epsilon, ck_tile::index_t row_size, - void* smem) const + void* smem, + Epilogue) const { auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); @@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); auto beta_window = make_tile_window( beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); // Problem::BlockShape static constexpr index_t Block_N = Problem::BlockShape::Block_N; @@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto x = load_tile(x_window); - block_welford(x, mean, var, cur_count, max_count); + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + move_tile_window(x_window, {0, Block_N}); + move_tile_window(x_residual_window, {0, Block_N}); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + store_tile(y_residual_window, x); + move_tile_window(y_residual_window, {0, Block_N}); + } + } + block_welford(x, mean, var, cur_count, max_count); } block_welford_sync(mean, var, cur_count); @@ -119,6 +155,7 @@ struct Layernorm2dFwdPipelineTwoPass row_size % Block_N == 0 ? 
row_size - Block_N : row_size - row_size % Block_N; move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); @@ -126,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass // layernorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto x = load_tile(x_window); + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + } // load gamma/beta (TODO: support no gamma/beta?) const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); - auto y = make_static_distributed_tensor(x.get_tile_distribution()); + auto ln = make_static_distributed_tensor(x.get_tile_distribution()); - sweep_tile(y, [&, mean_ = mean](auto idx) { + sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -141,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass const auto beta_ = type_convert(beta[j_idx]); const auto x_ = type_convert(x[idx]); - auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; - y(idx) = type_convert(y_); + ln(idx) = ln_; }); - store_tile(y_window, y); + static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); + Epilogue{}(y_window, ln); move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp new file mode 100644 index 0000000000..fb327f74a3 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
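+// Compile-time traits selecting one layernorm2d-fwd variant: N-padding,
+// mean/invstd saving, one-pass vs two-pass, plus the fused-add and fused-quant
+// modes defined below. As an illustrative (hypothetical) example, a one-pass,
+// padded-N instance with fused pre-add-store and smooth dynamic quant would be
+//   Layernorm2dFwdTraits<true  /*kPadN*/,
+//                        false /*kSaveMeanInvStd*/,
+//                        false /*kTwoPass*/,
+//                        Layernorm2dFusedAddEnum::PRE_ADD_STORE,
+//                        Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT>;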
+
+#pragma once
+
+#include "ck_tile/core/utility/type_traits.hpp"
+
+namespace ck_tile {
+
+enum class Layernorm2dFusedAddEnum
+{
+    NO_ADD = 0,
+    // fused add before layernorm, and store the result to global
+    PRE_ADD_STORE = 1,
+    // fused add before layernorm, but do not store the result
+    PRE_ADD = 2,
+};
+
+// clang-format off
+template <Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
+template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
+template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
+template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
+// clang-format on
+
+enum class Layernorm2dFusedQuantEnum
+{
+    NO_SWEEP             = 0,
+    SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant, needs an input x-scale and stores a y-scale
+    DYNAMIC_QUANT        = 2, // rowwise quant, stores out a y-scale
+};
+
+// clang-format off
+template <Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
+template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
+template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
+template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
+// clang-format on
+
+template <bool kPadN_,
+          bool kSaveMeanInvStd_,
+          bool kTwoPass_,
+          Layernorm2dFusedAddEnum kFusedAdd_,
+          Layernorm2dFusedQuantEnum kFusedQuant_>
+struct Layernorm2dFwdTraits
+{
+    static constexpr bool kPadN                            = kPadN_;
+    static constexpr bool kSaveMeanInvStd                  = kSaveMeanInvStd_;
+    static constexpr bool kTwoPass                         = kTwoPass_;
+    static constexpr Layernorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
+    static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp
index ee8c693727..990e9ecc03 100644
--- a/include/ck_tile/ops/permute.hpp
+++ b/include/ck_tile/ops/permute.hpp
@@ -5,4 +5,5 @@
 #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
 #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp
index fe2d24044e..aa617ee2b4 100644
--- a/include/ck_tile/ops/reduce.hpp
+++ b/include/ck_tile/ops/reduce.hpp
@@ -7,4 +7,5 @@
 #include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp
index fa3007d1e4..c93329bfbe 100644
--- a/include/ck_tile/ops/reduce/block/block_reduce.hpp
+++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp
@@ -301,7 +301,10 @@ struct BlockReduce2D
                                       .get_static_tile_distribution_encoding(),
                                   ReduceDim{}));

-        return make_static_distributed_tensor(acc_dstr);
+        auto dst_ = make_static_distributed_tensor(acc_dstr);
+        // init acc_tensor
+        tile_elementwise_inout([&](auto& x_) { x_ = type_convert(reduce_init); }, dst_);
+        return dst_;
     }

     // return number of pixels each lane needs to reduce
diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
index beb8c718e3..3c68147112 100644
--- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
+++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp
@@ -17,14 +17,24 @@ struct BlockReduce2d
     CK_TILE_DEVICE
constexpr BlockReduce2d() {} - template + template > CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, YDistributedTensor_& y_tensor, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + ReducePacksPerXDim = {}) { + sweep_tile( + [&](auto... idx_) { + constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]); + y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...); + }, + ReducePacksPerXDim{}); +#if 0 constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; - constexpr auto spans = XDistributedTensor_::get_distributed_spans(); // FIXME: hard coded to reduce 2nd axis @@ -42,6 +52,7 @@ struct BlockReduce2d y_tensor(y_dstr_idx) = y; }); +#endif } template @@ -63,14 +74,17 @@ struct BlockReduce2d return tensor; } - template + template > CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor, const ComputeDataType& reduce_init, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + ReducePacksPerXDim = {}) { auto y_tensor = MakeYBlockTile(); set_tile(y_tensor, reduce_init); - (*this)(x_tensor, y_tensor, reduce_func); + (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{}); return y_tensor; } diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 98c60f1b51..f0a6cf9603 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -9,4 +9,5 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 584ca70689..4df34e1e0d 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -5,4 +5,5 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index b1143e4a06..fcae3e02dc 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -5,4 +5,5 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 809473d53b..cc7dbffee4 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -7,4 +7,5 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp index ebf9406837..a4c479dd95 100644 --- a/include/ck_tile/ops/welford.hpp +++ b/include/ck_tile/ops/welford.hpp @@ -6,4 +6,5 @@ #include "ck_tile/ops/welford/block/block_welford.hpp" #include "ck_tile/ops/welford/block/block_welford_problem.hpp" #include 
"ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"