From 776c87ea7e8649242405c19d165573965acf1e5b Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 31 Oct 2024 14:54:53 +0800 Subject: [PATCH] [CK_TILE] layernorm support fused-quant/fused-add (#1604) * add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment [ROCm/composable_kernel commit: c3a4800c5fe1f7cbdd00f36b7bc4851e0299ddc9] --- example/ck_tile/02_layernorm2d/CMakeLists.txt | 31 +- example/ck_tile/02_layernorm2d/README.md | 69 +- example/ck_tile/02_layernorm2d/generate.py | 670 ++++++++++++++++++ .../instances/layernorm2d_fwd_api.cpp | 155 ---- .../layernorm2d_fwd_bf16_n1024_instance.cpp | 22 - .../layernorm2d_fwd_bf16_n1536_instance.cpp | 13 - .../layernorm2d_fwd_bf16_n2048_instance.cpp | 14 - .../layernorm2d_fwd_bf16_n256_instance.cpp | 12 - .../layernorm2d_fwd_bf16_n3072_instance.cpp | 14 - .../layernorm2d_fwd_bf16_n4096_instance.cpp | 14 - ...layernorm2d_fwd_bf16_n4096_tp_instance.cpp | 14 - .../layernorm2d_fwd_bf16_n512_instance.cpp | 13 - ...layernorm2d_fwd_bf16_n64_n128_instance.cpp | 12 - .../layernorm2d_fwd_bf16_n768_instance.cpp | 12 - .../layernorm2d_fwd_fp16_n1024_instance.cpp | 22 - .../layernorm2d_fwd_fp16_n1536_instance.cpp | 13 - .../layernorm2d_fwd_fp16_n2048_instance.cpp | 14 - .../layernorm2d_fwd_fp16_n256_instance.cpp | 12 - .../layernorm2d_fwd_fp16_n3072_instance.cpp | 14 - .../layernorm2d_fwd_fp16_n4096_instance.cpp | 14 - ...layernorm2d_fwd_fp16_n4096_tp_instance.cpp | 14 - .../layernorm2d_fwd_fp16_n512_instance.cpp | 13 - ...layernorm2d_fwd_fp16_n64_n128_instance.cpp | 12 - .../layernorm2d_fwd_fp16_n768_instance.cpp | 12 - .../layernorm2d_fwd_instance_common.hpp | 67 -- .../02_layernorm2d/layernorm2d_fwd.cpp | 270 
++++++- .../02_layernorm2d/layernorm2d_fwd.hpp | 103 +-- .../ck_tile/02_layernorm2d/misc/dquant.png | Bin 0 -> 36863 bytes example/ck_tile/02_layernorm2d/misc/pnorm.png | Bin 0 -> 32113 bytes .../02_layernorm2d/script/perf_test.sh | 66 +- .../02_layernorm2d/script/smoke_test.sh | 54 +- include/ck_tile/core.hpp | 1 + include/ck_tile/core/numeric/int8.hpp | 104 +++ include/ck_tile/core/numeric/type_convert.hpp | 4 + .../ck_tile/core/tensor/null_tile_window.hpp | 7 + .../reference/reference_layernorm2d_fwd.hpp | 37 +- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 1 + include/ck_tile/ops/common.hpp | 1 + .../generic_2d_block_shape.hpp} | 7 +- include/ck_tile/ops/elementwise.hpp | 1 + include/ck_tile/ops/epilogue.hpp | 2 + .../ops/epilogue/default_2d_epilogue.hpp | 28 +- .../ops/epilogue/dynamic_quant_epilogue.hpp | 140 ++++ include/ck_tile/ops/fmha.hpp | 1 + include/ck_tile/ops/gemm.hpp | 1 + include/ck_tile/ops/image_to_column.hpp | 1 + include/ck_tile/ops/layernorm2d.hpp | 3 +- .../kernel/layernorm2d_fwd_kernel.hpp | 187 ++++- .../layernorm2d_fwd_pipeline_one_pass.hpp | 82 ++- .../layernorm2d_fwd_pipeline_problem.hpp | 12 +- .../layernorm2d_fwd_pipeline_two_pass.hpp | 79 ++- .../pipeline/layernorm2d_fwd_traits.hpp | 54 ++ include/ck_tile/ops/permute.hpp | 1 + include/ck_tile/ops/reduce.hpp | 1 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 5 +- .../ops/reduce/block/block_reduce2d.hpp | 26 +- include/ck_tile/ops/rmsnorm2d.hpp | 1 + include/ck_tile/ops/softmax.hpp | 1 + include/ck_tile/ops/topk.hpp | 1 + include/ck_tile/ops/topk_softmax.hpp | 1 + include/ck_tile/ops/welford.hpp | 1 + 61 files changed, 1790 insertions(+), 766 deletions(-) create mode 100644 example/ck_tile/02_layernorm2d/generate.py delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp delete mode 100644 
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp delete mode 100644 example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp create mode 
100644 example/ck_tile/02_layernorm2d/misc/dquant.png create mode 100644 example/ck_tile/02_layernorm2d/misc/pnorm.png create mode 100644 include/ck_tile/core/numeric/int8.hpp rename include/ck_tile/ops/{layernorm2d/kernel/layernorm2d_fwd_shape.hpp => common/generic_2d_block_shape.hpp} (96%) create mode 100644 include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp create mode 100644 include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index feae5f791d..1bf74bc055 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -1,11 +1,34 @@ +set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") +set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") +if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all") + set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS}) +endif() + +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. 
${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs +) + set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" + message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") -file(GLOB INSTANCE_SRCS instances/*.cpp) add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) +target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 405325a2a1..14c6fc0d67 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -1,6 +1,42 @@ # Layernorm2D forward -This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. +This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation. + +# Implementation and feature support + +## welford online algorithm +We use welfold algorithm to update `mean`/`variance` block by block. For `N <=4096` case we can compute `mean`/`var`/`normalization` within one loop, we call it `one-pass`. For large N case, it is hard to keep `mean`/`var` inside register/LDS and then computation `normalization`, so we need to load input twice, first time to compute `mean`/`var` block-by-block, then load input another time to compute the `normalization`. We call it `two-pass`. 
+ +## mean/variance save +In training case the mean/variance need to be stored out (TBD, not supported yet) + +## prenorm/postnorm + +![](misc/pnorm.png) + +since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default). + +## smooth-quant/dynamic-quant +we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will do a rowwise dynamic quantization like below. Note that smooth-quant requires input a `(1*N)` size per-channel scale(in fp32 in our example, though this is customizable), then element-wise multiply the tensor for each row, then compute the rowwise dynamic quant. if set `-fquant=2` it will not have the input per-channel scale stage, only the dynamic quant. 
This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py support on-demand codegen) +![](misc/dquant.png) + +``` +# assume output int8, hidden_states is [m, n] shape and in fp16/bf16 +# [m, 1] +per_token_amax, _ = torch.max( + input=torch.abs(hidden_states), + dim=-1, + keepdim=True +) +per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0 + +# quant hidden_states +hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8) + +return hidden_states, per_token_scale +# hidden_states now is int8 will feed to next layer as intput +# per_token_scale will be used as dequant factor later layer +``` ## build ``` @@ -15,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd` ``` args: -m m dimension (default:3328) - -n m dimension (default:4096) + -n n dimension (default:4096) + -stride stride per row, if -1 then equal to n (default:-1) -e epsilon (default:1e-5) + -save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0) -v cpu validation or not (default:1) - -prec precision (default:fp16) + -kname print kernel name or not (default:1) + -prec_i input precision (default:fp16) + -prec_o output precision, set auto will be the same as input (default:auto) + -prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) + -prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto) + -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) + -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + ``` + +## limitations +Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. 
Beside, N>8192 case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. + +``` +# some case +# standard fp16 layernorm 2d, m=10. n=1024 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1 + +``` \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py new file mode 100644 index 0000000000..300f6c05e1 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + +def get_if_str(idx, total, lase_else = True): + if idx == 0: + return 'if' + elif idx < total - 1: + return 'else if' + else: + if lase_else: + return 'else' + else: + return 'else if' + +FUSED_ADD_ENUM_STR_MAP = [ + 'no', + 'pras', # pre-norm + 'pra' ] # post-norm + +FUSED_FUSED_SWEEP_STR_MAP = [ + 'no', + 'dquant' ] + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::fp16_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t'} + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +class layernorm_fwd_codegen: + API_TRAITS_DEFINE = """ +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct layernorm2d_fwd_traits_ +{ + using XDataType = ck_tile::remove_cvref_t; + using YDataType = 
ck_tile::remove_cvref_t; + using XScaleDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; +}; + 
+template +using traits_ = layernorm2d_fwd_traits_; +""" + API_COMMON_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = layernorm2d_fwd_args; + +{F_traits_define} + +template +float layernorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using XScaleDataType = typename Traits_::XScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + + using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename LayerNormTypeConfig::XScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Layernorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr 
ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a); + +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE=""" {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE=""" {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "layernorm2d_fwd_api_common.hpp" + +// clang-format off +// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep +{F_instance_def} +// clang-format on + +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum + F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_GammaDataType : str + F_BetaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_MeanDataType : str + F_InvStdDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_XScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveMeanInvStd_ : bool + F_kTwoPass_ : bool + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = 
f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'layernorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float layernorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + instance_defs += ins.def_name + '\n' + return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'layernorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'layernorm2d_fwd_api_common' + + @property + def content_api(self) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs() + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in 
t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = 
self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self): + h_traits = layernorm_fwd_codegen.h_traits + h_instance = layernorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + fused_add_list = [0, 1] + fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant + + # rm rn tm tn vn pd mv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], + '1024' :[ h_traits('x', 
'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 
'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N + for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_x, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1: + continue # skip non dynamic quant case + if fused_quant == 1 and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_XScaleDataType = scale_y + h_.F_YScaleDataType = scale_x + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'layernorm2d_fwd_blobs.txt' + blobs = self.get_blobs() + with list_p.open('a') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + 
def gen_blobs(self) -> None: + w_p = Path(self.working_path) + (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs() + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK layernorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script have 2 modes + # 1) list_blobs mode, will generate a txt file with all the files going to be generated. + # this is useful in build system like cmake to construct source code dependency, by + # reading the content out of this file + # 2) gen_blobs mode, will generate the actuall kernel instance and api. 
If in framework + # like FA, only need to use this mode + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", + default="all", + required=False, + help="enable/disable some feature. default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt." + ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp deleted file mode 100644 index f2f51de5d9..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp +++ /dev/null @@ -1,155 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include "layernorm2d_fwd.hpp" - -template -using trait_ = layernorm2d_fwd_traits_; - -template -float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, - layernorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ -#if 1 - float r = -1; - // clang-format off - // rm rn tm tn vn pd mv 2p - if(a.n <= 64) { - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 128) { - if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 256) { - if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 512) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 768) { - if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 1024) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 1536) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 2048) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 3072) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n <= 4096) { - if 
(a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - else if(a.n > 4096) { - if (a.n % 8 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = layernorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = layernorm2d_fwd_>(s, a); - else - r = layernorm2d_fwd_>(s, a); - } - return r; -#else - return layernorm2d_fwd_>(s, a); -#endif - // clang-format on -} - -float layernorm2d_fwd(layernorm2d_fwd_traits t, - layernorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ - - float r = -1; - if(t.data_type.compare("fp16") == 0) - { - return layernorm2d_fwd_b16_(t, a, s); - } - else if(t.data_type.compare("bf16") == 0) - { - return layernorm2d_fwd_b16_(t, a, s); - } - if(r < 0) - throw std::runtime_error("Without supported instances!"); - - return r; -} diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp deleted file mode 100644 index 2a20d1e057..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -#if 0 -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -template float layernorm2d_fwd_>(const S&, A); -#endif - -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp deleted file mode 100644 index d043efc86c..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp deleted file mode 100644 index a6ffc8cd2f..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp deleted file mode 100644 index 80beeca67b..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp deleted file mode 100644 index b362a550a0..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp deleted file mode 100644 index 9c2d78999c..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index c0c75f878b..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp deleted file mode 100644 index 1bcd0f8a7e..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp deleted file mode 100644 index 6b25fce8c2..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp deleted file mode 100644 index c4400f0f24..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp deleted file mode 100644 index 7f0e4898cb..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -#if 0 -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -template float layernorm2d_fwd_>(const S&, A); -#endif - -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp deleted file mode 100644 index 8c3a42cc4f..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp deleted file mode 100644 index 04d8bc1533..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp deleted file mode 100644 index c325747494..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp deleted file mode 100644 index c71db57a6a..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp deleted file mode 100644 index f3ca0932ef..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index 242f1d2dd5..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp deleted file mode 100644 index e3bfa8e3a4..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp deleted file mode 100644 index 90d960cf09..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp deleted file mode 100644 index 0960a95c31..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "layernorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd mv 2p -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -template float layernorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp b/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp deleted file mode 100644 index 22895e8edd..0000000000 --- a/example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_instance_common.hpp +++ /dev/null @@ -1,67 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include "layernorm2d_fwd.hpp" -#include - -#pragma once - -using S = ck_tile::stream_config; -using A = layernorm2d_fwd_args; - -template -using trait_ = layernorm2d_fwd_traits_; - -template -float layernorm2d_fwd_(const S& s, A a) -{ - using DataType = typename Traits_::DataType; - - using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< - typename LayerNormTypeConfig::XDataType, - typename LayerNormTypeConfig::GammaDataType, - typename LayerNormTypeConfig::BetaDataType, - typename LayerNormTypeConfig::ComputeDataType, - typename LayerNormTypeConfig::YDataType, - typename LayerNormTypeConfig::MeanDataType, - typename LayerNormTypeConfig::InvStdDataType, - typename Traits_::Shape, - Traits_::kPadN, - Traits_::kSaveMeanInvStd, - Traits_::kTwoPass>; - - using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; - using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; - - using Kernel = ck_tile::Layernorm2dFwd; - - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; - - auto kargs = Kernel::MakeKargs(a); - if(s.log_level_ > 0) - std::cout << ", " << Kernel::GetName() << std::flush; - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 4f12d91032..43f4e8c724 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -1,5 +1,6 @@ #include "ck_tile/host.hpp" #include "layernorm2d_fwd.hpp" +#include #include // different threshold for different dtype @@ -29,7 +30,16 @@ auto create_args(int argc, char* argv[]) .insert("save_mv", "0", "save mean/variance(invstd) or not. 
set to 1 in training case") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") - .insert("prec", "fp16", "precision") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sx", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1 or 2") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -37,7 +47,11 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -45,21 +59,46 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t stride = arg_parser.get_int("stride"); if(stride < 0) stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + float epsilon = arg_parser.get_float("e"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sx == "auto") + { + prec_sx = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int 
fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + if(fused_quant == 1 && prec_o != "int8") + { + std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; + return false; + } assert(stride >= n); - using TypeConfig = LayerNormTypeConfig; + using TypeConfig = LayerNormTypeConfig; - using XDataType = typename TypeConfig::XDataType; - using YDataType = typename TypeConfig::YDataType; - using GammaDataType = typename TypeConfig::GammaDataType; - using BetaDataType = typename TypeConfig::BetaDataType; + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; using MeanDataType = std::conditional_t; @@ -73,36 +112,72 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); + ck_tile::HostTensor x_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor mean_host_ref({m}); ck_tile::HostTensor invStd_host_ref({m}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); + + ck_tile::HostTensor x_scale_host({n}); + ck_tile::HostTensor x_scale_host_dev({n}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); 
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); beta_buf.ToDevice(beta_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + x_scale_buf.ToDevice(x_scale_host.data()); - std::cout << "[" << data_type << "]" + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); + + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; - layernorm2d_fwd_traits traits{data_type, SaveMeanVar}; + layernorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, gamma_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), - nullptr, - nullptr, + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_mean, unsupported yet + nullptr, // p_invStd, unsupported yet + epsilon, m, n, @@ -111,6 +186,12 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = layernorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + if(ave_time < 0) + { + std::cout << " not supported!" 
<< std::endl << std::flush; + return false; + } + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; @@ -122,6 +203,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + std::plus{}); + } ck_tile::reference_layernorm2d_fwd( x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = + acc_(m_, n_) * ck_tile::type_convert(x_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? 
a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + ComputeDataType y_scale = absmax / static_cast(127.0); + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_layernorm2d_fwd(x_host, + gamma_host, + beta_host, + y_host_ref, + mean_host_ref, + invStd_host_ref, + epsilon, + dquant_functor); + } + else + { + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + } + y_buf.FromDevice(y_host_dev.data()); - auto [rtol, atol] = get_elimit(); + ck_tile::HostTensor sy_host_dev({m, n}, {stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(sy_host_dev.data()); + } + + auto [rtol, atol] = get_elimit(); + if(stride == n) { pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + pass &= ck_tile::check_err( + sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol); + } } else { @@ -153,8 +312,30 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string("] Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + std::vector sy_host_dev_row( + sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n); + std::vector sy_host_ref_row( + x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); + pass &= ck_tile::check_err(sy_host_dev_row, + sy_host_ref_row, + std::string("ADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("SCALE Error: Incorrect results!"), + rtol, + atol); + } std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; } @@ -168,23 +349,56 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_mv = arg_parser.get_int("save_mv"); - if(data_type == "fp16" && save_mv) + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + + if(prec_o == "auto") { - return run(arg_parser) ? 0 : -2; + prec_o = prec_i; } - else if(data_type == "fp16" && !save_mv) + if(prec_sx == "auto") { - return run(arg_parser) ? 0 : -2; + prec_sx = "fp32"; } - else if(data_type == "bf16" && save_mv) + if(prec_sy == "auto") { - return run(arg_parser) ? 0 : -2; + prec_sy = "fp32"; } - else if(data_type == "bf16" && !save_mv) + int save_mv = arg_parser.get_int("save_mv"); + + // no dynamic quant case + if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 
0 : -2; } return -3; diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index 861e4a0230..a0f2db0e8a 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -8,31 +8,35 @@ #include "ck_tile/ops/layernorm2d.hpp" #include -template +template struct LayerNormTypeConfig; -template <> -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; + using YDataType = OutType; using GammaDataType = ck_tile::half_t; using BetaDataType = ck_tile::half_t; using MeanDataType = ck_tile::half_t; using InvStdDataType = ck_tile::half_t; using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template <> -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { using XDataType = ck_tile::bf16_t; - using YDataType = ck_tile::bf16_t; + using YDataType = OutType; using GammaDataType = ck_tile::bf16_t; using BetaDataType = ck_tile::bf16_t; using MeanDataType = ck_tile::bf16_t; using InvStdDataType = ck_tile::bf16_t; using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -40,82 +44,21 @@ struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs { }; -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct layernorm2d_fwd_traits_ -{ - using DataType = ck_tile::remove_cvref_t; - - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - 
static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; - } - }(); - - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; - static constexpr ck_tile::index_t Repeat_N = Repeat_N_; - - static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; - static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Layernorm2dShape; - - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; - static constexpr bool kTwoPass = kTwoPass_; -}; - -template -float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a); - // This is the public API, will be generated by script struct layernorm2d_fwd_traits { - std::string data_type; - bool save_mean_var; + std::string prec_i; // input precision + std::string prec_o; // output precision + + // if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set + // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise + // can set arbitrary(will skip check) + std::string prec_sx; // x-scale, used for [1*N] input smooth quant + std::string prec_sy; 
// y-scale, used for [M*1] output for next layer + + bool save_mean_var; // + int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/02_layernorm2d/misc/dquant.png b/example/ck_tile/02_layernorm2d/misc/dquant.png new file mode 100644 index 0000000000000000000000000000000000000000..28b1a61a14ea6774191fc2ac54f195cb86477f9b GIT binary patch literal 36863 zcmce-RahNS(v;O0vH&CAQ;#uSXgk-ou>o{J}|JTe(~?$ zlw6iiGoV~B1~zDgql+;P$B0}ZlBi*+B!QoWZODV>}fhZ zq>X(d?lJ`zPu+MC%xp<$T53z#Xsu~r4P!<0{{$Ky^oI#U0EztXDqHUpILyB*xh^nY zh<{f_zkE9XtDzE;;O4(uU^F;iZ2s#C6A=~g@5%=r8kXSSl^_(v{~XTm|L5UoNZEwZ z0lvD2r2`|}jD1XU@d{#F0ER7F*j$f2q`%e;qkOSRh-BR=u9@J6y{mj_^E|WLw_c2) z6!&!2M73)=2&ymMH&L5k`nzNwJ9t>l{G9ez#YJT`is2AZ+b!0- z13UIsQD;AUhDyq6{&{N;1GP~?%r0|#a;6uI=VqTD*)SuClj{-CzEW@_#DB1#967fI z|IbgYFrcOb5sl#|aBBNP>)Fp5d&jO^R3XKXDt;sO!Pj`BsP~p*F)hTY~EtU?SA;JtfuaWk>330??!KJ7biyEp^-I? 
z>m(xX(Ce%Y#;@!bg-%^@sZjlU!!6c)S&8zMxz;-|IiFyNDIHcE&e?tSm&^D(pT&iM z&FzXBDhJWM-8WI3I_^Zr2M&$~?gm+go(zAtRJH|}Z>dA8@Mc9I3_xw2XhppHV0e~$ z!_}ST24e-L{uKT6sr~yoh2esRU4T(?0sp@8v3uRc_7b+Z~EUO zlSG{a&2N;JOAzi0jj_*h7IU2|7htASJfF8p%qMEDF_V3_Bwgpu$_r}YPc`T0$#u3k zoH-4DF@#$z%)SurklhFLA~pS zL4gBbdwCf|mU2!36(fV9ttA>+`2cos+=O@28BE}LeTZl(UL#B{QSEWYH>W(4DzZNu< zObss#3H5&1PSlxyE>>KIBK2MQpnCUCn+SVMU)kYSu50))s@OQOG}aX|)1%ILSFL!0 zi#Twtj0df6K}NzL9)H&a`8R^9KS5SZ%hP~2aMGPmgvpjJr&&JNe|B?cV5u8xR{Q}x zk33Q8TVHj4U+KancGfwDSxA(uE@5}z*>JslQ{`^m_%o&V^%;q!@^`pj zJ-!U30hW~M2>fb!K-u#Mvc1jqXeEbq8=;cAXPp3YfBr#)30~v1_no{|6tDEF?f~Ye zs8+Vzwx|eGwJgcMAwz{PxJg4eBfFkv|Jq(!ZrWX&Sp}WW%cktBK3g|pv*4gESv%ZvMd zlRjLp^;8;)?1_yYCGRkBO^AZ6)QN^9|7=;2D$r;hS>fx0sHRc7gz@ANWdmI5MvhaC z`T>`dkM1-rMKHm^Q_6r(9 zf&9vRA+@;5s@q;qp^M&}3(n;B5y8&co#xJyiMI@0KBjCUoJbrp6m$04zQoDS%O5>n zbr3SbDBvBRsF6LT78a00j3W&@H<=6p#|Q?$;1KdXZ0e@)_;fcNBsGxU!)mM+L62KV-9G6d-t@lrXs zrFKuM6BCC{95W#4@2@hXa+lg{MYslcNVbcl=FiRZ5L*q&l}+Z97uct(oL% zeT40}{>%%SJ3#X%q1uR-9o9q_nciW9?2x>Nu~Db0oFL*v+QVvg?#16x|J#==(2&S3 zjHMNP8xE|o1JwK4ZlTC__3M+YwPp7Il6?X;Dqyqi@U}LQqt&()mxhrn2;qsbBJp3= zzr6eO{|p2+H?k*><+frjn~GiyXxR$D8L~#0Cu^n3cVhp9Y!%KE?fHgeo6)}+U^7f9nLzKRjI=1(8hK|jAN zYYVU!2kDMB3Oh51=I5fR%&l9DG(@0Z@5M!Jth{-$x`b2z9WWjew4$1)o!6ZdW1?Q^ zlVZ6Od?0`O-AX^b2jR0t%b_WG##;S+-s~#JiOdGU%$B&@du1DEMbyzRQcU)rP7XhQ zUqqvRoTp9PAzz>5EP0fp4hfnFx#}SuO-b>c{28-p)jve@6L1JyOM< z8Tn>3j=#W0s)274q-q3Z2n7aw23$t2IyQ`J; z=ZOkC>P_m-^k2=lV^&t4IV!eH0;{65^SLbnJJ!a{wdSyHY1FE7I(E0+zw50MDGWWS z{thZ9584R)QSspUH}RG&UKUF=1;#2l&F|HD2SE*v%-Fy-qIb*a!u+2;^`?!lCLhiG zpaf$?%hvYZF4s5jq4x0HzgDwb;wBlXO|Y~+C!{bR{LiPDnEzQmfzc}~(DN)SO+^@7 z)%a*DjC8q0L!y+Ccd{<^Bz&H|r^c{+A$|QxYo-L*sC*&Q?9i&A(%zDilTEnsqq$NKvy9`*kjc+(P2nkp65;rM#@)4vJ!1?S?zWl6?fYN~>U*(?PI zSLNQtpgES-E}y~wor^o=BO%l-39<$VZfm-_IR#IjU=^_4ums;$k z{7b1#f5$yott}H{;XvePQe&zdYo%EGPW@kCX_efPq%z)2U4J zkFEGrgVKis2y`j6GT0m#A^OP|foAMfh5ny+(ExQS?lwXj_jRxA+;1H+NOOV&`n?`q zV#c4pYp~#vi^@031)!&Y=oBgIvPIZF;e=ENF+&sPDev%p?ceL_4FC=IAr3O@FP`tC 
z_1*al5Z^=`i~mDp`=fiTe|vcKD>%_Zj|N_K>^0UMNw3j9x>KwW`i{`qOlkF5oN1lV zvTa$Th|mMbak3w0%ReveLii5{oNsd z{i)pnbAG&kVYcV z`jhwp_6U!)bd3|R(^%Uk76WN})Q?GCH2?wklL_hN(K-4VPK`bY7!_DGj ztdJNJaA&z|px;EqUhc*9dtIUFh>KOjj=(bzsY97ST?*SsvPIU7+$Nrv(^(W>y|6mf zE4J3wR|SKLE3jpsO13=40hjTNC3DL32J2yNV6DehNb^gMGD{H4?_)>yqYp|N3s=b( z(@ZTE;(8gJ*@7~AB(M*)|GajQ6xj6B3+iLq57;y><7H!HH%W8lphl(&9X@sXiInAd zdCYbOnm8n|3hHEF zWc3~ydq}T}crC{)b!0bxe${x4;6F~hZF$=6_te%gq#7K}n7c>v6~?f5QnU927A-v# zh0vd%Ln~RxUCpct$X6OsEwC=SeER8S5OO^#+OG5aX%OkKoS&V>*~_>;yCKstQvbwg zaAGiRCT8Be%qVmFCwHRb{in?iIFKp7FdBIybZzwqKKUlnK>Q~*7dQ`*l4^ck!NA})Fb~(@N)hiF zZC2c!MW|t)INtPuNUcfUX1PiuSW&xjqm2EZTs3bI0v&TA>&$07VoI$YpTGo7KY_L@ z4wQ`8lpeznDHEQ|u{R#UKy(hRgG0NJ@^$c%cWh@dg8hLLlL6UJnbn?~@`vY=2l&-3 z@)G?Km##CqpbCb+4so%=zMB?Xi}7NWUVRN-NWzvDE*jT)ZsR}yw7@#JeXs`jmc0k1 zL1|KPhROEPRkko%JP zLKcNh=xS32BJNaT+`~=lNo)JDCphJcalv!9j$bq3n3tyZyfH%2;-niPjlPmf?og*l zE_a3wJ?FAU*KUrdm3e-o zH0n`7Fmv{a@evfyo*Q0mIxAXboAH$lsN?x1>K=E1X@QsL>6M|sXTb3bZh`jkPRoF& zuyo>M47kweF*#j5&i~Kh5I!vXiU+wX{dKQ_4|SotLo z^n(kG_QouvS>xxz;|LINpkd?VSD)4#uXInd14P@F#61!04b6j&cZg%>D+)FkH(eIG z>u7>X z-+zVo-GV&u^54BT12rC<5Z_RQ{lYX1$0MFI^ujSy_G@wG|8;aK8~)ZNyL^VXexh8H(M8LB_PQY@zQB@3bom{VcHGstdaYdbzr~@M+*@^ z!(g6v2L8xl>Bx}=s!7NdaOsHNJo&2rO7hS4t(keR!7O6L|lPjwO= zFE*dil)$JO2Z{3}ILado9RgL~gOAPf379>KE}% z4kZliBU%!ls>c42(Ard=Sa;ahB}*E~&!$sBmXn`7xf?5eNl)V{D!Wl2y28IRHc&@C zo|bb^GlqJ+{vvP8B-*({R+eI6NW$Riw}fcF=inmRT?rAybA8ufv>Qd6&yv5&--({C?C^6kp4*oms!b6 z!KXTS@VkisKDG6OZE(F7S~PzS;W6D<(yKZ;beLOA2Ng7H7nMyD{95QYUtx_FTK1so z@z0k~TXw$>*$mr;lH%3Rj_jyyD2cv9HEFEAw#5jUcV>rv1a~pRW#M0s@{ zcebIE*~rFKy9WR)i4`Strji_;d>DWK>{RpOy+0_IEth%%@iIzXj7Ezlr;*>AUebKb zh>+(UNHNJ;I84W_5;`(a=Z|ki%i35A-sX|IN@`iOo02_n8E)q?4%#cJUu-U*;Z}Eg z%i1F}>%$Sdn6%W4VIFgB<})Jk+P{wWOM!vIc%T6xt!65--^adfzTcqQrUcIQYf?>s ztB%ZU_rDr}#8WUCf9it^&<1S7&g;B0J*;+A@I9O{vtPl$A3lRb8uTwVg0_elj*sIvaVR8tEOxBd@HGetYsacfY^?*O zNT4RTpe9eBAkRrbc#J}D(`P%~1$D}@%L$uy=9es`yFFY~Px0B&A|Ki19Q6NfbmUbp zj$eE0=pFYcSB{#VFWbJq@T|MxcG;Lc{xo!sl=7pq{96|o?hKb}}{ 
zlOG&b(-6evsgPj_5h*LZUVfgHDZq z9RdI6Izb)#Kmqz!z9$ChY%HJ||8@L3lNXrp3oK~0aG)9gcC-b}bPWN!Dh5LRQvd1r zcc%Y$6N9}OZQ%1ucsao!WqynJ#TG*RpZ2~=`nuopmRx-V$wP-0998oNI18c%!+QTl}jXJ{5MymPh&I=t;n3Vzms{L{h1ynfYl`5(SctL#D7R@IOH?xM#ZN^2&K_M z9@w+X6zj8IX=mK;5|>piL@R*)&&2kZ^EJMupoOW%M|Ntt5+xPNS-07Icv&3#U1-?D z$)tTDF0TC3f$Q6;%<`V!snz_HZ90=l4YxpeW-QcrCE1OGERZ`9 ze$`|cCGtxG*y~7bhC}V^U=K!}BST$p$bNhy5An&`PqA8EM#<+{T8(jj>*-%&No*!i1t(pP^Yzy z%N$bfY?oW(kfWVj?{%@rr&yum#wAAcMJ!ld0@8qbuPOF#fJ#Vij*AU3ey8RuYlJva zYb8{Y_iV=k&K@T3^P{D^h>4wQ8f^8MBowkc3$8hzW&&BtI#LSf72EtrJuLz@Z9YI} zFc3j{#(ci)XfnsY^u&GVhm(_2xq97r2G7BnR_;s0JE7rMn_VAowEFZhGmVYc-d_(5 zW<2c|@?KH&aDYf1+&$dmub!p#aRR5mKTb5AUY@ZApj+D7-0j-u4VUe2kC$C8pwDB0 z#B_)ld9(6>NO2dkZyv8DahbVu+^Jp6M}?Ljfz);Gn@&3Jv)+5!w|pEW+g{fTpYmgy z7)fo?qw}{oiLsHbEL$oHZ#1BSe!6k#uU4@62K;{QTj&$IoQFDL(iavgP5-;G}Ge9h%)SHUG9o>h zIE#RQ4iyF-XY`PP#ZNY7*Ns9l3W}hB06MkmKl%faCX72m@!-oNqobpbSa9jV!2@mX z4<4SLQ!A~lytg5%x$^E%P*4~mYU=8E4c~$C_mk?ks(DjCfBuYrWlIR@j95&oeb#e& z;5K^kvT=V2Ysn}&y;CL03L_@i+&6x4a*iOVr_?>mUnO!JeR$oxuJ6ppJX}2p+QIR1 zjW^a(y;8*nZ1>*yS%1i2*bUwEtd%Z@@h5)TUaWc5UWx71+}?&o@otH-`slmSG4sd4 zDZBrZfTE{3F;1{prcSNt)zhBbS*JSg=$Lzs1I0;`2}iGTc!Nh~lW z`lbqTGfE+W{XM)MBQ}cS3E<_qLnnHKuzI~WQd3t#u_mdbnJwDLfI6ISJ=QuK1 z?v7JKdLvU6b#Et(L$_n^3d_5RpPeFJ<5pt+I0hF>M`S#_Qq2tz=1tkvr++gc?6oGR z;<1uqn!NRWuaIt&9vK@0r4L3f`~A_`V#OHMv*kKkKEAh>i*{j5eEN%(2FB^)r5fYB zg{G#aZ{NO|Nd8fId%hDM50*MVU8)(p$D&tvxY-{&TWKI2Nae6xmKAury%+oPbh6rX zlP{Un+SZ28ZdPl%)4#jBE1xIY*VmWRI#;2EACggAY$_uYnUsW$2DruE(_N{zq?Ab) zOyux*ybMMmCNaIay81RkNI=lt-QDbb!Edq1REwk_9RDin@L?Sa)lyF4BRBbBtY+^Z9Z7$v;Pc2CY z$n^QGo@&1M3&TdX*v1jsKvpIXhW>+AJ3@*&LB>tgM6~O@1|NJXyabKgwNRH3wTs~eFfB;Xhow*3chrnzQS9c z-2rfi7nywSlU)`|)ll@`zu~lrN(8^>o>$2a?%HvM_MFPoCQv=*{fxkLT#jIjZp&Fr zi;3CV-Mw{#@0b1dlG}EAdb*y{W=Cs-^xC^1`jY)6f4+^*g>gs)q@>ocQ_rtbu|cOJ zCGP?$r}qg*JcKvppqS+lLx*q!n@bRK?ZV&)Xov#T#j37qTs+)u>%u$pJ&uv4}6B9t=iA-z;&BZo%XMboU zXe7LVSR&&6gfdsZ!STUk5m=h_K88+Z*ja2X0u&3gOwi zM4J<0-*ZmZ#DuyYuw?!*i&+l0vPg(7UGty^idp7$ex~|83dnFhtfJHS2DJ#+&sj%M 
z$=WBVo1_g(a8D)8ej&A=fRisJsS(CS@6lJIzQU!ET@b(M!fHI)H7n@e{A$Rm+A#jw(AP%LQq>tdulV;DT76Q4}6|mAC>fq)6l6&uy?u*xbqgL*jkDjWIw2_s_ zFeCZYdnq?EM1RV^>a!-JG|ruH_eF$$Moxu9210f!cN{sgz%OWo)QZw9G$7Q{@l?8L zg5v2CuzLi%b~Ap_+}L;ufkx zGy<eZI>|guG>EuM%`%#VOq>+y6h z>yJ)#5Y7E=6bE)xi8|6`%o=y?Xe2Zq{fyO7TPtEN&t=-cK-4r<-(Epjk}c7-Dy}3W zkm4mE;x8nZ^@|mQE-7Qe>QEm8fOsd3l|HfVB)}+abXkpsX#I=pMhQCf1_+6qFL3v> zzYoW-V%TR}L*II6Hj!g)af7yUW-D;5fub?xCeg`JkoiV86#7XnNpQwci{Nn;B>+It z)XDVr{W<@qp;6AaZc@Jtks0-sfJRBTkBrKLn9)T>QK%J}?zWe>q^4wdPI0asgMNPu zY2W@|G@FoIaD@%k#TN{3N58JCOSyPs1)+VL`+1SyEvkWSaE#IgX`b9SyY{a2(>`W~ z5VBW{W0qzuIKa*TRdOtfj~wyo?GWGKPVHfs zol!CZcV-`$kJl(-UzBWm!lTiVxpH}om|E{*h>hENI-Bp0BX~Xj!Kjo$luwlkj2xk# zT}ih0w1kwMVK4HhTA1WBOFnz@jTrZLF0^Qj^}niAAI}v+%8)k^)W{Cn-Tpyl6T4`k zTr+ni&b5rn7?OnfA)opzpIZ`?_SkLXZ2#H&u5R?Ry8^k8-^e|<_^zla3&b&m)yctf z!@ax4myAr)O5SJfvg*V_vPt<%IxTk|p?w3!@QnS2N4@b4=NoD{3ArxlFv`n1>W@Y*P;nH7P3>m^ckUvw?<0oB9Xa0;Yng>%~h0 z>FC0uk_;viNl!|ZJ*$xp{^61J@P;Rq@lL`al21!q(Qjw(o%$*Dot4+_3!jnbC12-$ zCF?_S(N{`YRdu0P1RwC4d(p+2;{-sa?b#jGDGZaEB&#F53A-{UY%DPr! z1N4yHPE2TciOJ2IC=!{O#^-reS>N-Nocg^3gNRmo-NaSzu@~pfcP4nLzqlCuZlU7Z zkjrK-szp4;!GI(9(5?peY%-&t>y|zKHQbU=^jkuO6KLdS)d({@)Jsb3oj$~tWJd*E zgBE8;7^hVEC?OFc6JwH^rnzd*&G`DU1m{ZlA6Uc^7`5&VHC`)ad91pPn4_A_y_=r ze)y#pHRk1|#c=S>5liF%3r`N=%yg(Dsl+X{sO*TLjeR^!ES6QaV2nG~%Bt&3bMAXZ z{)0fYuKM#K&IIzO0&=IuX>7nOzL-2=F(&+b{|2|*#q)3iEmcwOVYA26_0de>$Nh@+ zkM4S_wdWIH2Yg`n&A|i@4^JL66qNC-yV!kD^W=u-1L3h!9ywJk>4Eujt?5R)=M+Pp zXap!=aGy_hc5dDtHW)SP%Rq41dN(H>e)WjX96eHEBHOQUiZ#WKuYa>JxTWEm-95Ri zsDVzsun+~&;R4SjF}~2Ul9z*AEDYWZ6C=Bt-3=a7h-OwxiOiG10670hmn_Ib$A zTcZ<_fP`5#7&rMzi;#k|!P7fLZy6XRa3!~YLT9nbAeym4D~m8j*Cb2vWYwC|^5RR! 
zARG3)>uZ7e#o6nCWPybpTO{B3`1tX+d_Azg^t{>?6+}teo}9XxPq=I?GMT;ceVdYb z7dH|*L*`F*tiIEfG{c&6hoD_MsOIfD8MW5*td?^BkP+K$P(z>C`$brqeBJaIasaOu zX(WLi?0QBhafb|2%aRCqoKBN;yqdf}_)8jo)cE=NCDN%)gK!d=P0{w-`^&}8sWG#n z5>hT!cmjpmr<;SkJ+UYv5zH451k+^T*xSIgPiu^e9p1_mxIE@nq&G&ONOyYG+U*23T2M<0K67cL1lKz2l;r>7r?z<=26 zg3yQe;3p;|L~Hy^Dw+*Y0G(bz;FOr3Puu|a@Ejd6olX}9fzOr#Uo5M+2xAt!w7`S? zb%?gq7Nofythbv-xW?x_l$1Rz1Dw$?*BZ$3caPj4zxvus4$i2&@hKD+6x7ydEOccW zCC9udEND((+J#8w@t!pPpd^&mcA=@M>VPwA7=<(_k=n46)-(t;nt5P=I93C;#c)~3 z;vFYSn;YIpN~T!^<{Xtg9y;8i+|f1M+Inq#G8npRzsZB&4XUCq^RhMu&dw-C84uRY zrgPk%tGxQl!Z%w>2y(kP){03}-nXpZFj*gmgo9~x<@lZhWHh7VIMs1`zFxKnxWaJ% zz89c*+?a=Cj%wv!u4&Fg4eZ#MFn+aTqMsCTa4Tvl4^Yg`iAG~J1pL55KwR5gmrQii+?^ZA^jG`{xruOYRx?D{PaNd*H5YXiuLJ3@Rav!S^kJtDGFj~H5-nw9&(ijo5!k5 z5rcw?^J*L=_JaDFjFs^#zK8A2DlR;r+7?da$5Nk{2kXfn-Zda4s|}*g(waEU7gen> zq9&fc?f`t|z5(@akriTh2K;u){M5Y5OE6XBp%z?51T{SW>vTW_VZJS;xctL|NoMr# zu?zEcWiX_@!K@ttklz~^R3Jw=wekJrO>+QH_Y=fj;SL*)!8F>eRw z8dL_}EL%pkxuz1+EOo3sxy=$)U}WFDUxYfXH&Z_3naTAD3({3X9!5%JEWbHKee)QRWh&qr!c6w<6<0Z_Dl0Pvh4WGLy2MizPF6U z4}qAJubfg=R$6QFz4Im<6zKAbo?#5R`l{RU!cLStX||xVIqvfDAr?NZJXGc{uxOs!#QyQ6%`thKhZcrQ)EMt_bp*KgM`U)i6>h4fnOy^fzH?2*OU~;~> zS*h0tUv@|YAhY?=8V(+OoFjh#F7J5)_agWeHK#7g>)S+?BvUzXUmGYL~kddhkXzB@(O1MB!* z$czf|Z>c{m&(7M-0| zMssG`Q?W;_y-c_-wN#nqMia~2cV3V8LA>7o2|Pyb{{}beZ*aFnCGtIyBVSN8JGXn9 zO%tsHq`C#=?QZU@TBV_hcx;aYfHbF3u-}0l+fIB`f|5T77FZEs1~Ap(6=%-;q@z&o zn+X&b-LzU&gs+|qeCRtOON+#iX7>3hfL5opxa63P8x4e+?Y0d*6Xsf+lLJfId2<_X zM4qH{-79mtC$n2$=R5>ByB3~u1#+CqP2GNdd4Z5?);!_o5h`U`P2N@97;~5@m&)rD zc{?_z+;ba+i<+2@v@)wHOS6|J0z{~O*A8KN4>+#AGTM*gi6z%#(QWR?Pu5`j{cQ+Y zk7=i0=B_T*X3a=IHU>9f&=NzUW*Vw^1eQDnIdTJM{kX7E9+%!nwB@G~q z>6vcjA(br!dxhybs@e%|4C^S@RbMDqg7!&j?^G^V zzq>t}iHwYNzE~IWeC!EEAr)xxcw&LyKkS>%myD^Zv)*Wz$>2VjFBfK%!+r*797VuL zZG1%N+6$06834{7dQ5el*Xeb4JjZT6|HWFE8P~B@-re^P68$&xBu1S`4nf_I2;O)_ zU!Ihsoi8rjL&1{O57tj(QNN`WWz8&K+=CY&t+F)^O7=b9WgZV?v>~jA*qwSI@a-uG 
z#kE|xOLKsba0H=_Ci)q^{DaGxu2WHI%f|?LH_Ac!{%gxH$q%ckm6UYEm@5jZ_q@2;lUA0zQIcfimo;&@+rfA1iIq1_+KK{4fuA#pf7_7W7gV&T2gU@{ z#e=Y{#XkKpp~ZPWGX6jm-^fKH|F=O(G_SiFS65=Q?*;q4H{Hg#%d1|+ELSN zV=8_7Lx{ZmN9H5IJ-SmYljWSyp~5xEw|{32_2**&K99#qDw`sWpb`d_k`z|s^Tn#2 z?Y=Pea;&D=0_ijoQc`(mq~e5*j*hEBWu=2csT99^5l4PX$^j6acgb2^$?`+{MiV4u z+X?1`7f2zU+SuIOZi|x69C&$YlOYqFZUK_f^2DM*3gUgc=ZoB?ps&$(ZwOU;c9bA^ zu#l9xda4}N_pV&4wdSIW-_361T+)UvD4Gyi*moU*hOn=sEL@5WO-M$(MeF=v7_CTW#k z%}5ot>%`itq}=BaJhoz=Ke!Z~ksHes>22Tx@C2?Gl-8V*Dw}xpQEhnnhgu6B1HmA% z+0%V3BK_d-w^dxkzLw42-7$nK}zbdGvOjU~@OTALzu(5I^bxM}yJ+I4Nt2UuaE zAV_C=Z}yZQJn=lCk(Qa}ZU45Kx^DHsZLg!U={Jguc=>B3%_Cx=l`3P}#y(cI4J4kw zi71cSW}(LJhj0k!+tB=+D2Q#P(3euCQe=IDvdnqOMJMY`0 zqg$NeizJ*Rol?Q*)|&1_T|Yq((4OL*$-Zp4=s?B zL1SSTXsCY5n*mk97ZhiwN0wTMutbea<0asGuGFn^K?t|6HJ8?=<)?30#<6i1H+c;s zop=Ho~kQT9M$WJ*+3qIicSd z=@!g0pN30Xc=Tw|0oK6j!@&m*J{NW9>NwP$u5_R|ZH!s(Gx7;^-#{Edzil`Hf1RAQ zRw_9X;m0v!OTKq~v-&0zn}yIm;}CKBPN}i^KtRAbQ!LZ;B}X^ zX<9JrajHBa>uIjBJJDPZiB07;5wlU(U145(?9LhE7MX=LH85E$8q&&K=)O!hvORv# z;;wpx%a|1Ns)<`dET>{tR=Nz^k-9h=8wOOpER@L*s&G1*CRb{iHDgyOZL&Y8w^(8! 
zC50%)EGb%FUk4ScG^!2Ys;QK!-0zM96LG`}bIC!Plkf{36axUT2vE4oAg;CBlc=># ziKk^|{=7#|MP;&7ZD?d<1e|~f?yUl)6z{h?RZwTpdyMLam+0#6hYs}j z@d5i24w9^deIZ2PG4Ryu%-BsPzrc{>WqB%SYI&r zF2?n-68jtSylaBrtG*~w9x1VZ^|yitvDuZX^=vxoRg+nk>t{_gW5y{iiTj{Ej~a3~ zxYfKY&sK{??jdcx)>w@NMGy1zEN1>p*a`R4g7W=N>l14?TJ4nlX_#gB)w0JZ#kshs z{45Wehw2?M*}Io;#;rDwSNPQeduKf8#uvcPj;h!Bn#1QeN|eKSyO~{QL>}q!;bK;s zQ2)BVwTK%5&*m1*rfU@gg%@>3E2gvNUn^HHF^1$eF^Ws>DWwd*T@rJTK3miuz(nP) z71MNx@yp6bSw=#RKW*l~yF0%>O zEy&0LEV$J+G7(EM8$C2jYkJ+n&!$Zr&#Q{Ag?I(QFi#ksUv4Ja$}n&P|J39JV|dn% z3qo;)Smdwrh!4xZ%$)brF1;*?JtaQ!JkQeFEOet+BVcNssIa#iW+~qh`O?zS#s47d zSDT5b`JOZX5W)P1CIP5qe38xg%&hDR-Gn<$~R{()MtQZAu9ju^5rL2+tYTA17B z{r!DE#-v?5!0lyxT}Hg3(~{e}&{A3eCJftN)zi|Kwz-JIz zY$ds9%hO?k>aXe&RB#__lAoxy)}#9q;=ZEJUHlqkD>JHCSOutfpVay+U7NQIaem$rzcd6!JT?`MuDJCV_Y&O!?onmG%`LkX#_^6zy?m<<|^tg~(6 z43kicS8e4`vNLlx6UJ}$EuCF{FBa9ZQ@wooBCZ&X!B@pqN3I};e;t}mXT9FJ>C91o zHiPF>T8hTBT?~O(Jerp~Kx|`Q@BLMFR@5I$W_l{w*3S5hv14LCD?ekx1}vdfk1T1n zE#ZjAE4cv{q`K=9HF9S5%E`(3@DJi(Oh!_f{RDjj5hsExY}*C+O-PGZ97z&pE7;Rf zz>Bb&|2gw+m5jE+AaXrvcq?`%^(V)5^l76`<%900KBoR}h*VUN4B??ev&zm`GM1Kc ziYCd<{Ips|oA%-UIDcM<_U;x5qK4<=6YF!%a|h|h9!Za$dEZifD}?bE-&<&HOar;2 zmSWWT?u@Yq_`3>XmYRt)&}(WOgGJr4&yUUmy4@u8f1ru+RVJ>-ifV6Tm2Ovv(5fr5 zt4((C=gV>K7#RHw740Ce@z*WlA^HMafoYwM}k%M%#o1G;3uEDyaILlL*q~ zC`N8~Q1UU2%nBYHQs??jFV6#zuxpN=Qp%3DHh&FgxRC7t@tfFJzoU5H)34FmBpo1{ zVr;(OV~rV=dB?^xXRDR#mY*~&y?z)35K=OvY=Fw5c6QgGqA9`W&sCtp zK?pi!F1+L8@)`rFF`))ui|3C?%6$sSf=j&i%68z1Y=)4=!sVOMRHSf2&xph(Nx>s` z^7R0rgqj;i93$C5ZQV`X=d#vJPY$7wk}6wN;d-G1eoD>a0(@6-in>a##GDLIZbkj< zr5EePZ-{<|Nh|~B;#*;ZP#LG-7EsP;${j=!J+re*IvyI2$#1m2J)|si@2ovB?C)~} zO%YkQpwG5m_Feo_z4&`JHe_55#l05ht~{8Hop@G$iOJjR%ru9Qp7&DA`)W@2edSl> zDF;>2|6EZjb?vZ;+%?4jqLSD9>oZFmBRU*(etwI(yCV3+RhM81fiiXDFt+Gl{PRTO+%BW}Zxhf2ynwzEG8T@g73ltZJ|kF~Emq zb)gA7E^@j{gK^tv|7g_y@nH8x79{glqX$%5QTIM?S012KF}V}gn8TgTjZ6q>9PhJ2^M@z-X>A3wj`aNFPK>|^)Ad-`l%XOxUPaX28-sbfno zR?bw(YRZf_PT5gz+MH8ZlRTm2M6Q2&k5K`K%RjU*qu$A+R4&4dLV}?sG>U>%w|Bu4 
zG$c0j(U>fKVv;V&M1`>>)dDu--EPl;X?KS)XTff6ksP$2YNKfUtP`$syOQ!pei{en zAR3_1jPTh`>HCHBC9+fEA!1Mg?T%G%+akYHNmNaf&3VdgAa?bh9iKZCa)JTFI!C9& z!2?5q(6w4~^G$#6A{~-DfNbS^YRHguTY3vrYNBX}%U2r=(FhX_fyq7@W0xuDJfTdW zLLA`zp)VM2$^8cXit)lAK!YtpF3{f=k9*kFEwCS?W(mtnR9-#_9NCj z$9tWRmx?s&%FlS5fvVysXB-~QVBSHHs^gnFSwgGK-RT1Bfw=t}e#pr9geb$R$q=TJ zbE?9@6%B(!;TCa5(hVDJNn?t}o?!4=QL_qT^`uCY{fBoAcsICfiP4Qb!<o8{hGY#BH7_Q1 zpij9e=%lYFyX;wO$QKt%caQJAMV_t1C$4!5=UAZ=GLdKux}7_Wy>5)CrD81**rS`t;uj!X0oPa<43qs%m|#kVI{^){nb>>Ntm zu%_G&U&^Rj6HAvh_<${;9DEO*WO;D#6M0SQXVpRcEg!DGtq$M8#ByqNTiXGziO17Y z1(n0#-NPBo>u^=U+R~gUYz$0-E(KE2p6DUcF7yfHWw!uqJIQR7NL_Y>wM*MJu_fN) z0}UdW9 zb00UmTM6Jv=!|(Ep3lpLK<;%G^J}DoYs4izzL;oo`Ma4~g1e-D&C5?QG?!8N2qJcd z_sRFi8M%bc2y=-)H=I=Yh%KJusQn=V(e4{9wyE_5^e2_vADe{wp8ZC8QXZ0*WaNuB z8{c$=$?5B1YyU~MrBf<$!Ci52d3ZJ1_Q0hX3THZ5H@o5o&4dmCaj9% zT5d>r<_9QhR~JU=@h8JLh%_<&@p5|FuHSw=>DinTZZXvR1c@Zgp7GeaYq#sz4cxmb zwU0@ryOhf_ee=`8iP05%&9rNfZ1_B}NUq`^`TFM(;mCpk>T!Vzh3?jT$Z`WL%M0p_ znfuu3@ii~otjZ(`cclj#EcKt#oI1XvzpnG+e2EJcH;eN^UU8lA>RGOMVNh1i_{*rv z=PXy2GrQbBtSS%a@F|DrZo`u224<0Aw>BQ^k;5(-T*971Pnyty@uYXb%N^UJtWyQ# zu@!;{R|+Bb4!K$unHCK}$u7a*W_(I3cwJ=)5d##mh(`ZlQi=*ax0?*GTkQAO zag=zs>8`{x24XI*_>ejHx~1qgwZh_bL)AXAsSc?G-4_j2Y?ArnDz8kGo9h1|$a_w- zbw|f1GN;7)rRGCiP&~#o5APcks8(Y#xf97eF&d9QA~t|=;UGbJm}kgJT$F&-7Apl! 
zZ==$uwTkn1&3k>d#I1WZu|vq7`*0&MvR8qPeAd*3U}V>c^x~oR^Phev3s$|CDzm?M zvmdS6fEIyNB!KPP+$rFGD-^BSJOZRjn(EA26?4jX%^dZe1a8U@P6|IwDgr-(&DBFg zw?ZVL79NL2_-VZ0A1wu~A=Q#af(Yz%(rUwyv1%%JZLNF`hPLFAnU7s~iR4M&BeS_~ z>*c#C#n^tc(YhrTX-Sj5k&eMF&D)H)VhHDB^hnHhJ<=$=o#wimR)^t<>|LJF6Yuar z(15b4G|AEliU%5G`rkj3eiR)LBa(Y|0v5BK1Ns?N}_JYsoza=!Q5LN zG=v;91gmKch*_Q`ip3j$s(D3O^&$!R;{6({Um>ha^ph{CRX6t)jxjIEUioAg$gxLjIUF)q45!{e%f$!8@s)aM>lHrkkS-Muf zbWl97h5C4=Uy8;xQhwPo=oURtbyseV3V&M3OnYY{5#(=6w2h@^U%eU73kqF<{5=*% z_@8dcdm3GTGSb4L#c%~LMX&j!>9^6-ij(@4F^kE;qG#t*vS-mMA5&2I22G%bs9LwF{g7JP6`P?Fb{cQ}DX07mOkr|MidR#9%;J3KpI zCmN~F4xU{mgG7nh;`)_>)%bN~!)`2#0iFgjnOi3gZsruWvOE8k339zkat}cY@LQ3Rm#b} zq*=VMGncoU(LCvrH`c*j&g_PIZeUYoA0yPJdEpsYlI^Kh><4jNl{EfRQtvA;asSOo zW#jdD)(7;)Pft(Qe;cp%Sr+NNMJlikk^cNiiTK-L`@7?%#(Bn8fy(pM3RV1ZG!qso zOeRD2G8go8W5?IEINLuW07pWubjKPvouszRtLC7V%Fmvz5yEcBlCX_48yHgc16 zkB2hMk9?J>3bUIA*|&ZQ(FxoRoaXDj8xF1-I0(KyzAd!uaV>5UvOg`#BIKk%Ay+93>f?fEhY@2kZ#y1(<$ zs%yOP&-yrj>p3#Pnt>(D+EW{vurEi_>2tOJ?aXn;GlmOcw06hBRiO-9WDcbDlb~R> zO;kjD6+r{{!lpC!FQdgxw^oT&KR+I1*v*GF^+~MU(?65-XZR6aY7@iVdW~%dBs(Kj zEW>&sG&UzY>7}WiKYqVC`27;rUh_ISRKuiOX@XPK zVIx&4bvJjRv?9Nj?)e3IX?R80oz|bu28bt2a0Tgl8lH0~%*_NJz^~)`Ter{O*6%nFHf;@XKl@4F;~m z@4r+-@qG&uZELj(<>+DSA_=I4SYhp`8ydwVaf^ojiGrKdn#Rnl{ajwa-ulLwo5i+g zKy&5B5r5s7Uc|Dt;uXyC-}X-q?HLaz-362Vh61|nTNh~lKb-PTY704jBxT`H@p zsti%P@+^nm2)Lq(s4&tOly>%>UO67OC5h#hjcqciiR51SJ$5S;b`E~wh?B`5YCXR9 z(!+}im*9ScsK?)D~m36*Cp=FTM9M$Y3 z9!z=i02!Jjyqzw4;<+V2*^7&dR|DyAJ?P3LdIx4AQUibBVF&iCv&^MV;BBupF?Sw| z>S*%x^h}(lHPu;rTLgOCiMA4_i&04Pg#~q($dXU^T{GGoMQ&&G*QM`_X1tsaF{`So zf*F+i`}?e{EI=0u+O9);uo(@;u6%z$VnSjy8JU`#92y+_(&_<*92yc5)FO!Z6%`eQ zR%gmIISqP$1_cEHolsY=N!BuLLFL zezcv0A_}?_5~l6$n&RW9F!v)F=I7>$Dk#{FWkSXXges6r(S-HL1q7BSCeD@6HR^3h zGe%c{r2K*GM}RPfJy5>`_KpKFA~Y0CgG~Kxaz5C}F4+O0gm%RZlWro}tTpd<>JbCd zCuuYOFgbTr!j+JO3E?k{!wfdKxVU47_6x?1VkW=_Ay`!>zkUSx2*%WF&EbU{E?YM>pr%Z_J<1eK%=p- zj$~l$9UUFi>deen-=_gf88D(DD?4(!)Yj4i68XX$yAJ_V%_7$!VvbpC1t8Kiu9*RHcK#GZq`ugELARNH74F2Znc#b%@mx1NE6c8FhHl9f3 
z7`I3sZ*q?J=${_vLeX+j!U?`bQ|)ONv2#m39iPbj8Yn1>G5+%BdA$sObYW2c`h_;b zUAWag*E~k3^Fgkk%2!g*BcSdXqw0=0%O}1=*HpsS$H#ai^&Jwq!{02N7V8xO9GnCo zzwRt}ettF?N!4g^V|Q4{T5L8RP6qeyXr*msVBlS2$NBmBSb0D|fbc`J{f_M1**`b_ z=KC^{*F9|$+fvC@#=#3+R9 zksoedyGCoWvpc)HH^G3Eo^pD6I+ze#wHd=lc=ryM)%a&=DZ^tk4ud)>0Rf^d0U=@5 zJFrw#pe(Qn&oxl^&V){@W|7}W;%JcEY#lr(n5-T>mZv#S-%Wb#WUtCOQ zBOom;t*g6YzcbPyJ_h55IGVyKcdmaMNh%T~sy~{}PpV<6B%99nj?4wH&y9^dnwpv_ zDk}Z2Fo{`E8hrX~Ole5CvJF4ziR4PJ22p1Z5q=ut9rUA~T9PeG+RQf2yBZmAS-nL4 z)~dU2oo8C~-LKYa4rH`tlo&XLGE!ue)1lMF5a*$)BB_Z@^Qfq|5_>^!oTOKF@>$kLO@4 z?N9=9pfR!A2q5-+<6o+!z?u?RJsd7tbxWSy$zW8+%`D404JQs2qCtg&v!?37Ob;(| zOI9iX8BOok(0dm$cWA>z5oe=2oM2%h|#}5w+YzBD#OJxIZ=finQn@ZtRTFSdQ+fCOTjF zPE4$L>g!jHgnYcv9VyEiwv_fKoZqzGWlKgU zr$%<3LieuJgecskdH1(Z72(~_}y}_eC-}mD@Ed3ViiyDUA^_Y-*R(04=Q7g-(P!+w+^!X z6lA7GQ^KJWU!JEh8X|UT%PZ>R{Ga+55Y9OkQ4|WG`-RL4imh;-jRd=HlGk+|(52;Sq>p&}YBi-UbP0 z6_%8UM7(-tq@^AB^XF$o1Z>*i=x73i_u<|i_t#G7-O3JyOzs~JgF{240&MW$8{^~S zM>7Ni4N0jecx_hOeSDy}L=6oo`pGEZqMbn!>c>jd0D{kmJ_G~=@*v`1AFAiOos8&p ziMQ<{N>DKKb=Dav_>YP2^ui)rK<4SXj-i$dKqgF1PNo}B;h&(ek!IeBd6#SCS%aNM z59h8}*{^1h9kUSH%;V^K9;&7Ko!vPoO;PPlqdZBfJ%^lll-<42bcN?{*tx<^IetC0 z* zbDxh6q#rW8p=Pxr#g0c07Zq$>)V9p}?3t;QUydsX*NeaN6<2A(HiVu{*L+)p#!+E) z#ATA>y0&pb!MIGebCuw=jfp;3JXp}TB~F8Ws_yju-A*)ToS1S9cT zsLo-i?Cf9V5RR@k8}VrBtt1{x+QRGn6)jp8p6Z~eUpEeddT|j{ADr}9@Jmf3t^GzL zd`S;_@!Ux=d;}T-zeA-M#ch~s7i4^$7=yE@#P64tax#-SPaK{%4i`@N*sxtJe)1%4 zSTVHC6^Q7lgt?o3xztC1!J9WQFrPE^9pRa9#z2NSMiP8H6;V}<&z+@2t4G4;0HQls zhWD zo-Y^2sx$tQA2+j5DvNBZk$DnV}j!zui>?p6!0X76ZA zO|OHJx9n|*MGr<}k=WL3QGX;SGTPtx6gRR(nXqdSX6>RBf`67?COfDtU#pM>9ly42AEF&W{^3T|y=fgFa7RP)49vT{2$Z+7pIkjB61d)Z$ zk*mAAS4ewYT-@HlfiK$qL+JNxpx$u@K~7W@TC}{ZjQ{&17j!2Ygl3HyZYh>liNx?^ zxdr;{Vg?M)f;rp{r|;ie0|Q}ADR(o&!iITatZt6~f=Q7U#{<{`+r^u4Rtu@{Ul+DMyFeygF%rT14W;kB!Cm%Q zAur|)FFp43xooepza0BsVa|s=2{9AG27iHtp6PF+jAl2+DQ2|&n5=W*7K&usn*EXQ zj`7!M(w6z)$Vd5-K2xDlKym7OxfxphH5JFIkF{#hZKRTS&h>S(`)Dt(uDXPx%Z`2_ z_!cZOn!;P$mwnbD$Ts 
z1lqc+gPpsf)CToZ$|d%#gkhhg=!|uLmV`fyOG(D>qa`=1Pi@KjVhlKIx{X-d97n~3 z;4pLvFwpOrG;!kCDV`h^dGbm@htC-~^7%{Q=@(F7!p$zxw(v}MPphx# zxDBUH-YFpP*3Jw*Vnf{=LA>3mV`P7OHUkOvnbYFDc;sONxt=L3apxw7PHe=Qk@}zK-k!-6Tn3n_ygFW@e_Nqa#xT zG=ZaogLqsvt3O`9K}14AA>@I^2l2`;77UpRjl{;rLPJ5oaAZDE8hhUJ=5v4~JU%gD z=j%FBsO|Us7OvXS6BXU^4O?cw>Y(=Z`kmTj?hjN$!#bR%PmatwBi|lZ&nSe z3oK43C8(nX$Jz$61UM)VnbJ4vjPRy)ts~T;U{qAoV@@Fu`1Uv=f1VeZY9G?V&A$8g zIeE2|X9hA2ilrh+vG^}`{wjx=Ex2;qq1~Ya+m4zUZZ6+oKYt3D-M;*tH)CBjCb0_T zTmXVz>QNEW?6{Gd-3I(z0TsSUY*lkzo(a>z+(gaQYx&h=7=0;Q?TBHsQ|%!|+R-Xp z@*9%^Mwk^c?y^upwi)@y)Yv6waVa^3dz98}@l>tWOr6ybi@BMSXagJ#DXZo^CQnz} z;-Vs+X=6vJUPXF5|5SBl&TILR3m6-+vq5v$crU7a=1N({t)FulX^q|6!x6th;r2(Z z+Zw<4#H0-BUARDEIhFV_%M=Y+j``Bpxt`jZ=WP9{(!N{d2dH!YkV<*0POwr!Q1_NjS@?D zbq8vHTm5Xa11qxDSD5vjr*|Im89=Ccdw18-21K5j;sCh8 zIxIMn80Hoh&PNL{y-xd6w5Xz|j3zs?2;KnxxC5XRHa0d$v;emO<1`-#a2!w-rKO~x z>fCqO$Uak@!&Js#dXuE@S^)v0=l!KwG;9%Cy%ZHnc|sr6d#9N!PR#mB*h%-qi2~Q4 z@~G^Cq?+}A;{R6>Z(pw`=Uk*;3d!??OpZRs_)6RE7V^P zmrK5`@CUdOc2ENrLtWV3ncx??K}CG{)UISdBwq&#ZJ4gD%EiFjmL)IJ){d8D6Yo|s)x?lBdQa{`%;7?6hD@nj9fY>_-xku;N5)QQ}J?iOCKh1 zCE4+W;@sB@h+?r+A$j~0?0cP`XfZy6xNE~+c|JSjFSo5(2OwhTGF`vCTwAHSeF>L-`SieU$XHVk z%C-x%%|gc2ZgQ2P7q-M}T(nDtKH_b2r5q-|Tf(G;ifi;_QSpG{<|fCfT~@&)<^I zcO9uvH3VS~yz;KriDR3SZD^u>TCIkEW3=yGDON=3NR%=W`n-Edg|JE6#{yY(Lc-A| zzbKjWbT)QLrro7ea28-#zBPISEc;Sl3;dSs!)xGFJ*t_Y9 zW}|bxwNaix=3Bf`Z_OoJppH{^024BR%Y1CMLoT%rTPu>H%-GB zo+Y+2kR^?$e-th3ubH;)lbSEo><RB!e&C?-8sr4xyfP2bHerVyt!u|wf2 zNj=CrGY;ZCQ}n>n3VCQv+B>@9v*G)0Ee*qL|LH?>6zs$k0~S@&_lZF>Mre2~dkCH* z$>ivluhG>!bC1Eh#SJ>iW6({3R1DYSQVhly&_1?*r+!Z9jB@qkz|E{MrR`?t-x(9t zTLi1vw9*blf~OmkgU0!nIb)}YH2pb5esDECq9h=^J54#ilE*P56fa41uuLx_ybN8i z#>?|B&uQtBmQwHETX0opFn@B!RZs=GWe0Q|CVh4rNzJBTrKN9AaKN)u3CvsUA__z> zCx0xxf{;Z`sl*ZME|uimE&JmxX0^BU8;tNIIeXRO1`91+)jMc$s!)l$@Z{V1#NdTV z$_ZA&55QDa1`OtN~eBws;=_{A=Q6=y$`=GE8Q%r%x%E`D*k*?v6=|5WTU*r{Oi5?OGS&HA-`?=sRUZ zTb9w6rxlem(+Y;bn`S`W6lXG3r!KZtpW|f2nCPk|nNJ+rPzG=y!!J>jYD0;e{<<{m 
zLKQx^{aBZt>H$*Z_SO+b^pv^Iu-Owe1-#Y+di^#RwdRY)QxCYhFz$=n>DV&rJc-rF z3TPN)Gycr>8cuH!fmiSZ6N#*HI-lo#paQ3Sp#pys$*;<-V~PZdDBug<|9AM{H4HKgauF;zAL4(9uYN!U z{?Ccn-u`#kO~_lsTVx-c2oIe9-^1YQvf4?n{yYCahyQO~7M#}UjFR(nqPa!x!wB3J zl(UO$odQ3D{AwI}UN6ZZa%+I=t#ve(80(zJ!7i>Mtg{L(`t+){VC zSX0|UI284fe5Vy9jJ3NrRABS=W=_He2hoeI&FJ^PJZ3b5lSV}>%2m;+TAt9s?I3&z z0V_2Gc4W1&2@SDlk$wvOcfKu8Qf`if2j(T1^(6|lVtALho}an2NbghFL@4^@B^S%{ zO%kTAkM+Z$Ze)eP`)Z_{5(|GGFK2~^-Q?!s;>BHls613O%tB8KfeH+g=%-DtIx_T- z{RpgC?$J+S+{~pu;%^Vi(?-KlPS+5clx>~EXthR$UC$uD3d|W&q_Z%uU_v;aoUY2N zo(#%tyVebr%u=AK`9}2Inny$yC6th5#EO^fT=`yOo@AW7L-d^?=?RLX124V$cK%^qDqY6iY=gdc*{e;<F3 zR8cyxAW@+q3lv~|yMwg~L4`Jp0W0zd7n2Jbf|~moY?$9B@~hx)B01iH<(>Y&`47kV zXt=nNWKoTPESbP{MM@X~W$n%Bsl1Ac}lkXC50$M`|=`_l|& zG0DjE!O&iJue3rqKBugtq#0=3Kd--Bj0zI+IyuMW7Ye1`X5$*&*PT#>U1L7awvY&s>^;mlh%(TYrDQ zDWGF)p1WEZwFUmy%kcB}F9T4#fIzDx*A(lSc1b8N4-YaS&t*ZzcXM@Dcdo)VtA7ne>McbzIBv?QO#*#~2qJ0{V41BKYmb045vc78%gWL|d=&<$ z9qA+J16f#DUcY|*OIum_H|P&1CME!j2e>Bd_8+eSq#tT0>-FJ0u+{@|w9CW!3r`Iw1zv!&f%Aj5 zay8Q?2=t%7)EOeVQAG&}2_>|or4dH_5OA0hQ&L7#xmo%731u*~s3|DAr>1^O`T*X$ z(dCS+e+4uqLPDg_=e2;|1T&Yyim0ZRmW;u!SNkeljmgOiZ7D$VE(VP4h}w@1}2k*n*D14;^yb)XVz|n&O2Ud6I4cs#S!4K+mw)$8~{Mg zs-|y<+xg}@WrlhK=?qmkpJrb{_)}d-sh&-?FwFohK^23K>;ln^63){!)wIe`T2RUl)AdenV67eOox-%SG->M zyAQ{60#j1(3q=eKHv!xYp=v*BSMCibM4*DjgO0`|Am9V?DZ5|GXFb31^lqcTua7i4CV$n(gHNMMK-0}IW_uNW9H zKx44KZ{CksXyR#}sY|a~3bC~y7N(7SKT0f~ z7ScXyD;yU=u2{h&V8kJP9no1KU+(`1qvrk-4zCe}l@JDlDHI9i#x}{xyh-FJE4h}U zh@2iwI4>|zpzq`$-WJB;guHKOpeAxEDn|7_0MKJ3L-1sMoyG0a6qW0jH?Xz4O^J;a zr9cM=O?x)5ZGj4BH|RwigZqH)bE~Q~Tds{lz=it$J>32)shgb)FJTuKms`Jg80`i^ zi4KmAYq~&a0e3QyhY<7^uyD&GD57uKqkgLW@B$G9kB|@v{mlZbyLWQCH=r!wxf{Me z@jxDBbtEJt0Mb`m>ttp|8-%kf`&3p1*j)>Y1Aqqk3Ziver$3!`p@0@&UeP88m`_IF zvgWUi;{e~wgBiwCQh0x;k*OwU>qriXu)Ym%2I1Eoc%-Oy^}x?8C?ho$u>mTNm4hPy z%ICyq4su&%G$cd-mVv#!{n6oJrrVkg>M^A;0M3cA>p{RdyYu)U;iJAyKtPb0o+S*% zr{99M{^l6)fPkGY#2(WKdh>k05h3^tWG%mt_7UUXiCgcnB0wVZf$2Us(?V>bnsEuE zV5gM-eEV7uG+C0w9IjHl?l9i;Uyj3vwf|p)^6lgM{`kJu!rWX)90bPA<6|9W8{k6} 
zosh877Ok-$Of3&gn}~pMZ?K{*vvYJ;nJvr@&SClc*2iv-fuFFXUg#5N^;{#pr(2O>#yey zc!}tbl#HInq;M0laya!sgY+?Ny6bHUntJn#mk_i+xdVcGBj+rh4d_AbIG>@kaB9JB<2-2EP|gedN&m8o}rOZQAvsKcZf_OlCG|< zfq}u??XPz4I>_ra8?dSHKk7p#mKff`6k>&mDJd&oUSA*2SMOA`J>QQoZz(||(-E+# zK_IW!g)B?88WdT4C@3iWBq*}mNhRnvtD&u2#A`&nsT>);dvZkyWk$pOMW*VD1BHnt z3TD0Wb`DgX)6tg?Z*71j(I4g@(Vl7J~kb*gbQMZ+#7?}-|0J-XZcbr@@ zetR^dA}i}P*6zl}#{NE?>)0blBplodurbi%Wnvl`A6HN$rLlpV$OpC50xONp@J~ZL zv|rHm`@R(3e8A?W7%&1?zVD@^rGc3Nq&YP;wOtKFH592jy$&dwAfS2@COJgM$9D(H zp*h=s_${!NPIh)+i=9tbA#=|E5E@2eVq&Ltz)Z(RKC%Fq=0M0!{!D;E0n$lAd)fo? zCkThFE@#GAVEsK@9l&yMtN--&e%)^cT&_62%?Y8l z-MpRPF#hxYt?wCN`#@qVl!Ma+6Vw!NWC*Q}K;t&XBF||f)sNLpEbvJ9`2!dI6yp@B z3u!Fjx$Dibww&DPJp}RW>nyWGlFcLF zCt}EgdnwI=@&VjZq1xo`X1gtT0tpse_)X&VKR_X*qm$mZA-U}CTZPfhxs&-ro~A!5 zH-zFxp|=@#XoJUHm0=auE00Sc_L9$#$0w9xUV3^Y>g;nAvB^~FaQ`?k3J3^9$9DW$W9~NlpGf~-&A<1$d4GF2<;wn| zlT*52Qbbs|{mT~~iwpmhf=c@a1gFu=c1aYVZq(4wFg!mPpj%7H zlcxbfDkKxz7^`|MD+dR*h`tgWt~m~N_PA!+Ad;Y1&^Z1d?I-^4b}Z07kv&+yTs>HV z3rR%K5{f=U(J6Lq`$Y6l(kQbOfhLjzEW7{ur*>wXj2BPd-Xh=Lu#J_aTu;5ygX%tv|E!Fk>9*6P{fZK2dT?^y1JLagh-> z818rVwqrAw-g!X?W@d0+w#ka`ndi*kfi?+^ACX;r^ zT#{z|Nzd>yl2hIMhd13ZKTet0BdY2NqEG9)Wm$e;k+)+}bn>n69kKVkR`BR&#xk<# zAY*g*mNLYt-Kyx8eVF6D=GRrq=4d%q)}my-T?ZC_vjUTlMP1@m&OhY*stV6x5Wn~^ z9Yq00fuly`mYy%0khC^_=4Pkkpu_DidRlip_DT0Chy8`D?((ac0fbQVDdrm-JOceYwt0L|MHOVto-MwQzCB+c$kEB zyLEVa4G-`1DguHYiO*S-vcFZ~|G(SB$#&RVBpxg~^A?kZmW-SDkB<i_Fj%+`#IF-~EK}5#+4tMab`Q~^P=3QAf39U8gTvq0GtGpU zlGrF9_`-=!M0%BHgi6%bj3Hz!i>oC)?_BM!bn4n1)JJ}uOQVxc+c9e58GGl%hY`c= zvS?()?k_O&mYPiI-M1i-arqow;NxaqdgXn;)hdR>Chc*2Q{W8= zdN5zB-gQzCe?)XW0|ys9O+c#8VQ0G2P>(Z4h9MTp-Hx%2Mhy+?D_zsvA`kQIh@7nL zMJTlnLpd(Hlg=R-A0Z(U^sEs_Wvp7ni}(5_OA|ZV762am7gg|J^~QI;9mz5&q97X| zmGUZXax^!awpPnYlo_e|9^Q82D8p2~-*}CvOQTM_Ia8Mdu~Ip@#6VfEmYzkrNnDxa zU@7?Bw~erB5K=WIgx|$acQcX7rm#YueWYI zrC@A}@!pneIV4$F;WRL?4{5?P6aRxxB9bKmJ+iER(BqQ{vy*NK7g2SX`Q7mgF@|Fv zzNBkQoz8?}{0hgPE{HBUM;pUtch|c4c9LG3cdZzXt`W+deV8!W9^KIJ)$g@a7z?X5 
zh6;Vz^WED~n$a5-e<6|&x_8!Mwp*fg;ceG@cra!}O}ahn0*U{7C0Gsm(G=QEd62RnBlxG50ZAze#eA2VS7pP z%4^sQ606v0URBrTu_-mr0YWB7Fbx|Ddu>{7#+kxiwR)N1*&fmvZe;+V`WL`C{|s|% zYe)U&s{amBr5xpswBY{q%^{#9Nw#~bo$+Sfn6Rm$AnQ^)Lbi%Tlwy6L@Q--W9Y4d> z$?!38oYH>l9aL!7iohR(M)AG~?7{s}uNVZn_#5Xh3kCF$ORlPZ`-@Mc5EnN)>t9T8 z4QbJ{?&9@7nCm%mteNROgWI_iO;HZ?&nJ6K?&;und^`z}MCBMo!UgB-1xLQ7FMkQ_9ea3v+nK8gdo;zImoC zHxqSE_JWksCpnGQU&#%7I!g_SGG3ZF^hysD*eZWE|rDdb*vU9plpEwt*%7t%C>qp1>TL+G2%=KB()q1|x0b=S9^L@i2!*n4M zRCwC#;NwERR-5Tp1*n_{@q9;Lkp9bhn~U&TRe}D;zT6csRx?FGjCHyT?@fE3xlHG` z`URQz>o~`uv~!{cghA%rn5}YMC0LL;XSS2#=g`%&zcul3hn;`#`CERQ{`oU~@okN6 z-O3~3>PpEwZck^eJca&XK@?U6oR*UeMXhPeu7>?DPIK=of|=r@d2Pa}7T$bA^lCa6 zPGfTCn5%0FW`dmW>liNRiC4$OxK(m_kOjYL(6@E5W1(PosB}{2_@{tEWzzZPFU}G3kJIu&fEA>EXMF5Do@v{cR`w>=*(f z$Ih7g$#}RiL2LVA#70A0O_s0rj?2Q7t;?!=90NxLwQ^KK*3j?iz2jt@|7C}pg#=_D zj@!1jlwR7=<1TAf02~?R!1%j1q!hNoF+s_nybFmJGzU!}~%7 z%oz2l=g9+o#BhE-+Fks^Y?vBa=}!x+lCE2AC^kAX1qZq0&NuuLzu4s4P!COt@)(Mf z2-V&(jQDFjvG1}Fr7#Uf4uB+B`dfuQv9z6^lPGr+b`ICK9g9_+qeM3TgF4s>pfYLq zhH~=P4IbGwGH~Fi$4L`|=|Y%uMB4wlV#+*_V+dHGf0I$O_^RL5Pabyvm$sc{E#LQ> zriOF1lM16H<-?3+Q85L}x8)g4mDquODEB+})4L2XpEUZA_YpamS(^MvP_2lk4I171 zcIGLqQkYJ;npsJ|zp)Gx{8E_7KRePUp(Y=FP?&daz}(c~Lnab+jq~j(5$!|dOBLEXVdoXnIk|g|Jpm2pCjEXRx7iR=0Gpg)j%vmJgEn4IfI$X z(#xsKxV?8haWFk9s)TIEBl`tmzVEw1LjlTMOLXuyo(ns5?nb$;Q=PV>>t7tnYbW_R zWNzDYmgD&e!ykh>y0kbmbJk`8ZEFau#2KeE+!s-PxjxBix-uS&;(x^1us<&`H#4lN zi9cs^KCqR|#)8SVlYVm6GCG41?{&QaMwdDMQgr31j{COXAR62;)wYob4iVuV-ZrBf zxh-j$4@Dn!d15VoBd4CLjD-CV+n8n+dJ$e7dnG5Ie#+$%6e)iGnuBYW^ zr_VkQH~fIA+)>kUnwSH7oCjW#QscE#sp&R1mAEFPR*t$W!{g{uMb!EvEo4BQh$v1+ zCx?!nb$K9lcq4Gx9m~+b!!DW`!Z*+#fs_7eCk{yL`O|Q6Da4zSV%x5dhjy1(Z1j;} zaOIT#X1LF5n$g!?y3Qglxg-fwpn4v)nhyd8)_?Oo8G1c8zL-ptzUdbhs3?hC z1qoyXj-D+{e2bFL*S3j!Cu=tk>4mB5Ycs>FLEb%E_IDoKc>coXjoUmaLmTjZ&{?Pg zlx=s)B&Z*_w6Y>45=fOKvfW+c90`bYUujs@SNtH!RT?<2NKizbsk-yX5{efNh^0Wg z^>`iAM-KnaP8t60WvHdi;G;WLVAob?e4tuXOw^LN?;Bxc5>hu`m=jkh0Rn$f3Ywrm z-5_VM$e;j+pA$`&^vx%7YL_E*_lPsAxS2tcpnv_G3yL&YBDSEjBW1bVz^9 
z>h00Y8_K{X;Ugj1%wqF>E=;;8*(@n!3RL){Z{?9f%6=ZLlGX7@d8G2eq758P69fwj#T|J^dhVMC2k9kqc%C6&kx1kqWR`d#18KPn`s;*6@e z-r^0es$e-Tb;Q%@GT0a(FAJbYZGJ0yFti$*HR0X7^~K?)RJJ33**Jw(S}HYmI+>b? zqa&j$aM&(_Yc&-;dKpIi2AG_T^d)I&g=6}OXY9>bWuiu6N=&$5(e>GanZqBCE8)^D z7jz;tsQrWT4|AHb2S@A6<8Z|{)f4|O+eSd-HMA9|BID!<-VixJG11C9Dphh&?~$G# zy`D&Py_4Qk5T}e(qxhuno#ONenX9H?W*t9&OF-J&O!_p&iQV% zKKw?G#xR)TwyYH49sKW?6?}^tKB+iV>I*NpRA1@AAO{K5_O>yob0 z=kLE3Z`pROe(!?k*ZZAZCp>%lesg*LUZ1~}E4DuuUpX5%D+o;LE?pI`w*e>PFWu5u zo3`pw>d7bTtAO*ypw^iOa8Ruy1UPjGW(VmPnC9rs^-`SJkt4T%*3rPplPmA9ovSr1 zqe>{%GU~66x{Tibw!>2w2R%7Dxg_NJ#;H$IqINl)PzR1xfDO2$TX2*o?ZZZP_v4kD zEB+`Zn!gBn)o3{H6<7ZGFN!fmnxBH4%%-|rTW~r@@>SldbtU$D6CUnNXKHNF7g)zx z^l|Hz+11ZBPM>*B!TRwuJ+V)P;Q3U6Ot)REd57QLSi|M^B4PVq|H;?y)Rh`8*}id> zlWx2J@<6SbGCNMJ?s&h#)=Pg@*{6dyHTRyqGUb@;&cz;!bETi%d;0$JV%f9xAGY59 zsF8Sj?~WVt?f1NOf4p`(bM<8{hsCXJSAk0kU>>+@9vfkQVJdJ%m3@u=@;MhTJLTUt z*skgRE;@Vl5wU6FAy;`fnMOT(f8Z(ij+%Woy{8^?FPm49u|{kEnfr!c{Bk;Gi^jHz z#Yf!BzrIzxn{DN(i+7y#b0i-DfDfJv-g8vytUv?#{!HtE5(H9olyP+iAzLW5%{F-6e%To0?YLieJBP!t%gvd|xN1 zZ2BMhsVn>9^rpW$TSGY>rK%tJ5XN7S_41ISJfz1uF>Oi217Hj9+>h@UAMa#-l6mE_ zzVY=tRk{Vh#)ypc`^%24e7E;*i;qo<@JE3IIPO#CpuTu%-8aLZPZ@WwyZd`zd`{xeHNpMzXJ4*g?-XD1FEgm{>QRCH zyyfw4D`phr>?sdtE@kW1R(Dnu_Ei3M_sgaWfBkBX^JSieRL7TkM7U)d72a(8d1;q} zdb;x4&l@9W%N@BpLy%$NyxDrC*8)%Ktk+wU*js24|04OcFvqmGzCBYWXge#dxcty@Zo87JH!17b@gSBhC2`gMEb zO;8IqW@im!_wnE<3mq4@zxb*AxHjZzCdmBYp$!9YMjz*dl#)<8gbA%}o~sD$^) zDI2H7ZBcQcr6H>=d*pGveochWd4zwobN`e-b6CF3f4z9_^T+K^z&mA4O*|x^9dSbT zSTOLrgHk{O{M;czP!JG;1I3|;2natZ{(B1b|He~e;?Py+hW`HkhK7clo16If_}*UG z3sq)qLG$g~w^LJ7vvYF{k26RJ@{)y+(D(hDLn$yAEa>6pbpKDTR9G~HSoXJXg_r0m zhZhekfj%z^OOe&F=;^o9dF+NqM(CAO#dm)pAoxl|LA3Ss8Z8Fmbt;T&pLC?9EhdXp zNJvQDym=ED8JWna)!w%uE{6Eyf=>o)W79u6%Erd_3`ZDSVjv~W2)MucSc&@r!B@ru z5-L02;&Z;XzCOM{EgM-`U(d(Q9b$_wj@#JS$Y$K)?CgAVvMRMN^5e&kwzjr>*{D4C zaRdaT{Bn?;y?rK+-SqwS@%X|{R9_6m`N@fTlR5%BFjV%-EsYHgM6E2ewD<%BysjJa zii8*lV=wmN&D6zu5e*g0(VyUrHzOP-*8 zf#5DH1F^EVKiJv|T}H#g0>-?+-34Js(#R;QrpDda_#Ht+!1YnqmoG1XN)UE}^)+C? 
z@x=xgZ((B@na!C!IfRne^P=#nsWG-c@w7@UEiE66$PmO)r{ck?s;UNc_7YCQz85xg zzbq43Um=K$Pfx$X!m_Zi7#D-FcLn2@FAX1&u21{xX=VCqfG|E+L#b~f#N2*u!KhsB1|-Y8NcEEI{?z=8kNHb0gtMM^^Q9*zK;tupTef{|E1up33J4<*}!sV(8;cPig4kxri}W%@18E@FAAt09Z0@*W1X%#n`)knegPr{)UqONVQNh2S|TI`+BJQ&!$yqp|}Z1NKGA@0Iw``wQ5uTi35kd$J}G(<@G z%DYV|{u?5NErUQ`xw!ZCP-4U&^I`V4n7Vi!(evE!pszuHJYbZGj(1 zfzOJlSQ%9nL+-JY%B2d8kmwi!l8#Tgug#1{}0 zaT6?KDEuaUaA61)H4W2j{D85EH)taHry%cT>F#HPUeh2k*n|TlHvHO)0Irl~DukR) z@t(%e*!C!0U*qb~z$a?#h;LkOGxg381HtCs2nKN2KteI+^8S{bbh=^Vq=UL2hW6r< zSNT=5GIp1XS`H<#2`%V;{k^j;tGR4$z|+*9j8MPjXnwyu-u)L$l2T>ewjXbrSk8lh z@(!r3_5AAUPk;Z|j)I{s_t2w{*IaJoH zcTVq@_AyBqo7bqX4S&m00-8%$(J+5pWaJW2?0VVU5wfv|C|#$oqeuFm#wSwm$_+c_Jnaf6 zb`3NCo*yMrZIt3%41vtt7lX;diU1U1#w8J*ZVvafH|8$zoywX{~aR;sFx~ z$epDKLN-?8mhVpy%i20Ra$%0BIPV^kz`!D8H&ZU-6iLDZ9!h3GeMKgIUh#}6G0YER zHh;`!cZILF_$kdTh=F#?=`9W}vn=Zzv`c!t?ZZ;!DZb;?V5zloyCM8QM!5V8gf0k} zp<7#9wY9at+G%EHc6D{Nla-yF9U2;n{6fOOfHa_UXFMO(X5V4GN>;Vi&nH`!oIrpV z+3~HOcRm`E+eyH|si%@;Y&~3^?TAQ) zVo7AM3Tij|oYSFvC2mle5#}axx+qCkxHQvFEfl|J;FK(@&76RMEHdZL5A62H>ER(4 zi+iRT*bhp;fV@}p_U0F4?pk$jKN^R_W#sH5BasB1WQO?*nH?VA;m)2>O`59SV>aQQ zvyG(H2)!WhZrt=E@yG;alvx=dB!K_s0AY+U2P|h7mzM_ZLX3=z5)u+S-yMSM736?^oR@b6cd(GJl;_?+}7#Jvr57L)!J2{@Aj&*faIo6T?r~6DEI~q8Mr|Y z4-dtZZ!gB$Tie>Q6-6W3`-X>ydwZ|0kEKSYD@=m+r&%s~Q`k&^7XpfDM3SS@HYSBYIOd981U>eFrUq1N#tWaMrj0y_Bb_Go>L!<@^@{h6g^Paj|3HBwRrqvR(V zCTh2fSq^e?az#RmpCAy(ex=RGz`$>FFzIW}fa3GznUW)+f^KV+s6g%JYZ?s#pWfLW zZ~2c%xI%%NpOPnAt`fk@bdq;LJP;t1yf*rAXPUj*1MX${$Y0{&O-?~Hao!1X2R-=% z>mnAKJdYx@yga-7XE5y>oYXY*lXJ}`{>LK3u;%0K?Y;H$*-5wEi9%rMe=G_2;9POv z7~o@Auho}S$uXBQMTlK3=vkKJ^?bGg>jWj!kVPFT0iU!X7o>5ig+gx09V`*iAbT~9)Sx?%UkzJ?U% zU8J(A>S+H@cZnTLjjZb@2SZQvPI?}xkDdqD`}%nXUxO;y8savaqwBj0(&KA9{8#AV z?}WW8=eucoM%6&_-zS3Hp1H(4hxw?e*u;W@XKoNMf=i9|aMT&}w8cobUb0nZ2Dt~)b*S@dMD zCFVNXk_MT}kx^d{=ZocZ{ zYaXl(&H7=g0pxi5Ck6Imroa-v{E_7B*zJu@Y^q;w5d9makNeI1iRQWt4Zw9VVrD zP+F(%tgXK*-*R;jODl;GsaEg0!|pQ+6OX}O%aS#fkIPTqsneM}Jd>=$@dHE2?TOY( 
z6qfe_<$1*O-@`wkYedcC{xA_DpLH2#c``8HzeesF7Lc`wzoFqcVjInAVY4IQRFG23 zjLnKmOVR{go(n`8#^2pLjOzU%T@gd7ze^D+I(~1A)>b9}_1MC?KAl6KyxfkznGtF* z;(oiohx*rpV>^Xa;8CwbJy-CXY1V=YH6o`ufd=;22Vkyz0p(lWs3FSD*z-T1*}3VO z>UwVyF!SkG4a!@SPzbV9NgE6>=G+E5lZrg8wY`nPm$cRR{nhX2n&L)|@X;m^aUJa& zOx4jkCH$gKK@N;p;k5KAn||>*I3tr&Dcm;=I~C4Rk*P-V(+Hcw{gm1x9(bdPhvp#5Z($ZAl%u(TGwEUjj zNdKUKf7!)GxNy*n%Y-8o-kS+eb9y`@twQf&HFe!w)tVeQ$jY8@fY#dd`_v45TCcqU zhae6?DxYotx@bA-3t5WpLS*|fg&%2lg(fcZBOjp6k#FTLFShYGk+PiV;dOpU!>=i$ zjD{?m4VESQ0|Mun8Y5UZ&)5vDr4@~y{UEZ;N+&Jr@8WgUd$?s)?cjn)FSC=)CmO!Q zF4VU!7o>2-BxPAw_KL^F>^kwu*nUk1KdZOKTYb{ReioQL427vV>skp!Kj>7WSNNXo z#A;K#GJ2HG2dZ^8$Ti#3lmGN-YleSswarOehZ>fx3Dff!#@GDXDtgP0BKr49U!07d+ISLX3;uNI$EoJ&JuRos)m&s^Fh-D#Zq__Hn!CRU-il%%w8X=&fkBP8S~De|SN#{&$lBz&|QSKG6r zZQIy&@v7RI_?fu4o`nz{@Mc%%kDgMnp4VoHRfc1rJq!TJq8Jn!?Ji z?Gv{iG6FD3lzGgi@uPN=O|?{*^GI8%$O$8lu8!#$;Obgz*HD&*Sl!XzMa~y08uqp!2{(Jq^py^@hE#?PE(pgTCIUQTvX z5(Tk~(@oTF-GDW>z+-czps!9I!SZh?R z=5jOt5w8~uLH0qqZbd9^;^1hp0@pJ?(5x=kNz#@mbap~Zl?C$$6 zMlk(a;HQQyA=+C98m9|RG$4@=c@hE=DAnt0N;n#%?hIu{Bb_VU_D7@F6z@mZ;x}l- zI&Uw?;$JJ{P0Vwc`oN=Gn!sa+q^tRLwrR!`j&TpplQgdB%qj)*+oL(?$Y1`hWa_8K zy4x4ojuh5!C(YM>H2WifA0FU@UdBi`(dCK3GeatR-qDet-6MH-k3*#RaA!(Qxq7DP zHZv9f%V}lMF2u7vT$GXTeS*~GOLT!UjfqE|b8a+p*zEmU}Bhgj-PSX_S41I=(q3O}Y+Ja)uuVra7zH{v>3JM^f>vu^KDq5)%uWt$Zamd87m_ zoKrE{n&O-=OK)1D2k%R}UPtL`j1V_lT%N=(bR@xy@D`?2_&S-3Js+!96Ruc{P5A~H zM|)Ds6fo1}dALV`4YI47TVq4NAj_~*$ObjzLg*sK%E~Gf+e1wogSp-0&j&Aly1No( zwMT~BGjmUgmi+JzI5+D$VxuZW`3W3m-i*^owN*vV7J5&JA_^~QH%29zcR4vZ@xQ-1 zTyFI%)rNte$KDs~rxmAFUJmJFg)QHzkIG~QJ%@_qLqQ3#vGQta?cS$rf3k#{15u7H zRc(0|n5p(J@LZf;$zM%SBqV^by8ZpIZ6hl6nV*}A)A$J0bubC=@nx>-fuUw*W}aVB z>j0TA9ohEz-EgSJ~_QoG;UJjv9;YzOibjm z{&TQUUlHRZQ3fkBUESJq)S@*?_VT*l*f_RezQ7N`#t43i0y&ACVd%d&@Y`wri zCDzWGhhZyEK26Ij>7kHiZx@H9e#8nnU}9wOIXTW&eoji_p`(*kRNNb-|7)BcG1{CS zQ)wq?CA=`TA1I~=>HV#KRL%xIEBNx+tJW(z;+0lUdUf@AkxEABD{>N2(g>^PihG81 z@KeQzH(ohRvp%8}Eb@8);=O#?`$r*&jJtRr}R^C6u^RS<#0j|5HF)GF(tn-LBo=eM?AFm 
zs7V4F5&iZpRFtOr2clG%v?s^Kp<`f3hB?Bnalp{k{Q6N7a}#+B&0hmemfUD|$`D4l zx>7Rp>FFsh0l_PDblg4v5q8PWRmfxvJ?`vOh1E;P-B!Zd^H)Jy&e`*V(d02Yi`_8JtorHQaZpZQC6=IXHp3Cf}?|!Rj7Clnyp| zM;hNYPe|AhC#vCqu%DzK@NnY-OlFhzK$-h!LN?=?E9loW@XZa^MCp;m8$EbgPFN~d zFAjYP3w&^B2p=Ec&D9l1u*L!+UQzQFO!>6^-uj6`cT4SZ>>$tIqUHz%5Dl zE^BMCzNIDYypx5MRbF%F=g*%X+0E?Gbj3_hVQji-$@>Bq(-RpLlYAY_-yQH362zWe zq1YC!8dVPu{@B=2)5e##=i56C4zD$%Ui&HN3^s}<(qk}ljV;j8(f#@J2S5c(ZBFzBtmYmS zLvgvLQ%A_9B$=cXjh9-;^?qO_8MjG=CH`sNUC3DA>ujfzB3NR$G+Ju;=>){SO|p8E zXj3_{rsgDo&nc^5_h`a=$R77n_q1ao>yoI~RSp-ON}eBRS62#2(qEgj`RA$!Bp^z> zzQ~sl!@anIYLx;&ji~$1m}sl!{^yEot zzbE8k{Zj)E90+Hr5pp?cY3cLtu%qmwez+Pb%4+j(&c3$y-lPGWiBG*Cz_2B!r77fM95q)~vd7EX%a2bQ zp%^$7SMATfY)VDT!&e*@E~PTRB7TRVpaV>4cHiXtw|Q~ag1%G5}%qYF6OrckJr zsF;c7=H@P8gpZN00!wmeD9XjvT(z~7lqAX~8V%)EUisPOLGMvc2e&`;62Fgn^3JU> z$PUV)Tlur7$SU|*ob&Ig_XTxHX{nUc@k;wMfFg`oW0nS0M)vhrCnFAZ4-bE*+zd$n z9uP`^%Wh((tu_bZgehe^n0WgZn4iSi9^w=wh}@n# z70mygVZv8ZRGgTYU}R;@=66{GU^nEKFHvocu_>C};=&yeJ4GhX#*-mg(T3avJ*X*> z<1;hy>qpz`@gV++UIPe~S|n6TdO}j=4ftexx3}In(y~Z>>>iBeVT+URLeB zY7BpoSCRS6LZP!cSR;q-Bt-|hsT8Lytmnkh! 
zC~>Xqysw)h)P?B!Lo{ofUQG%ZNw-4VVmuOv0NBvQ&N$ZBT9jW;QBmD@ zs1(FEy!>L#<(%G+47G*&z@!GA|H84>Lttd8r{u>gb!OAqyoDV98LZ>Sk$%%mG9H8C(1b^)EG|Z3YHG@9tUP?qKDlrm`0xe^zwPc~Y@1=-tAW3a zpl@q&rLFS}-SleSQZp34{T#C*+tTf^rcHjErc+K|HQW4pdjscebKZZPMZ#l)7B$rJ zEFMeDamLg2R6?P4Ff92zDqD;FheP0Y zt@+^lY{y8b2DU`3&-vE3u(15*CY$52CY^EB*7~a8wgtZ78DmSctBX&(;KAX<#fb&5_2%>k%ITB<1{m!FE_u{PxvkVB1DS=+^WBhcyHH=@eF@ z`C26NvD;%dY&3{r$svyTeqylYcRP4^rdOAfD7 z1GuCbtN9ocRC@ULO&?qScMuMTYRuL`5{wjozx?^ek{`et-kCT1MR@$Hdm31$+^`CO z@pv@cYRUe zi>QhxJ?cA7Cf*ol1-hS_s;i}iub=7|wDZ-!wN}^vYY>msM!e8KHl4r7b(}?^{-&m; ztw_o!TE2n<>3jOBm`6%EtuP|+of+(FCrF{t0-TU#MctO zv_ou1xvYt!lk%zBYAJj}uwXOjA{`9Sz=ef{x3{+k2S@>*qr_^+j4Bn9Z3Hw?D5nhJ zGPFC(zl>SQ5g^m0qHrcHXh{hXrv)0R!FG&quqHNxzYU!Se_p86W_r#!N?V7qy_6P2 zTT}Dvk+HMu6}KS_dzl#)z(Qe~UGpdtxpeSO=nba^HlcOC@VU1kQmE5r7l<$5O>!O< z6nN+4KR>1KKNjLZLe==0TF*SK&I`-@)Oglr2uWvjm(7<>mhJi~N|kC2y{3@NH{v~Q zd`Furx+N;H!ZX2GmPz zHA`LeS~vM`(8)PY+GGA&#A|2VyAe!Mi1uJ^I&yZ zW%?7E!qs$QNst5NYw!NHe%VqRO80x)t6;a3a5dl&;EizmM7dE8*Ry^emF`Gl-3<$R zNu;r3#po-HQR+{#z4c?+vn%#Ty|uNIt3t=P1N5V|(eQ1dl4_7PJE%m}RCn8RQvT=6 zSFD0!{mcAIt{7>6odE@umeAKN_YI4>b_WfoJ`)&Wi{)SOg>94w!0$j6+=cBLFvinA zcL6#rn0s@@M?P-Wj~iQU$PAdXw6*(s#KrR-PJJE)naI%gabsw%-;04_xRriBakp?s zH}zVbas8#NK9q8^Be~k+K$DaxwEqAnz)vxJFR*x8 zcK^A{ug8ZtUpzz;9nI55c6w1!-lEE)TYg%rhVXzLqsuyd$Vfnwza-dEKh?L0Vo5HD2gUcDl7qSB3RyMi!s0bd#`dpm*Nktc$P`KC< zpA$sQ&Z(1`+o`!asp_ueKng{czFLl1B=C53LP8+zl){DsJ{HZn6u4SO-m!^x)eaG& zxXpnd+*!xd6p(1#UQKC`XQYfc@vH;tf({6GqU}ZjiD*yHP0h^x@S>U#^VJc$)vR!e zIx3!{f7#GGrHX>Y#R_+qVfH^@2tfn~Nn#HMZPr3SGwc?B#hF$6NzyCnjZL z!OxXXB&E2nO9W-YxQ26AdS;bnhMB`+XzSYKP2L`-$cii#xg#JuW8O3Op+)?e$slgi z&>tkTB285vw?-N2mNl?bEB$;vX8LBcHEUifO7^)Rl41^VTm=lx zv+yX01tTFc_?rckgOd&g6@Kt~0)X~jpHJ;P7F#KoYWcBPfE-b(0J60$%Zfvtq8}M8 zVdMlDntLK}v`=X7JS6!Qdx)j|s^)Rvl1tW$=;SIEEUy*cS9REYNF9N$LPUVLG!%Hv zrx53Sh07?#1(~Lg1-oNuM8)3)VAE8!wUm$urHA>U;b$5}hH7HVFk|q4`sBOoK%5(n z*u^dB*t1N1KRud4_T+}{s{;w0+6Zkl=J(yZ%3$-p(9&hSRhglfW5p{Pz(WJbdXkqA zS%t4YEC(HO3y^87lx?@};ZD{I%^_MayKWtaCii9XpBx?;@EOe_`C;Ss7>KKZ)Q&p- 
zaaGC{)IUj&_>scyJIko$mgcg5$8BO9@(AnvHxLpx2GHC0a2s)tUww_##TB=O+BGH^ z<7a+%Xd1^Q{twP8|z?G{~x5CMxJMWQav*{HdzOBaXB=@X1S6P8ns5+#b<@p{^291Jn-GScL04C>(CX7O))uXve>4w`vvR! zv{CuR#f3rw-N*lzg-S@6Y!mSlMMjQXU(a{fyUWQfw46?1J4HuNWJTM-G%;BzR7!3M zdhi~sU<#eYO=$I7DK4~paOrdrM+N8>NLyQA9dvJIG>1i58Ts~>!6-9uGPrB$`Wh$V z2Zr~R{n270zzyzTB1?>cFmW&!D(%yV^in<*IeH5(Gx6W3_&;cxn5d|z0E&(aF?k-0 zQYcg~@E|oTOb$r1m;4C;ObI}4!d_8mYik#|>xr-M@bVhB`QI)tHCwKiMQm@%&|e z6^%&y)@G8D$@YOEN7~m3T2v4U!_`Xj;@Zc;1IHAJko+dEyRo_k#pWzUR% zBz`+EX*o|wOIZNp(kVWZ2kXB+KO(Vnyo4sto$t(@&SkcWv{FuJSVbEVnEus(sU&SH za(ulnRC%ky!^woI{S&n(K?He*M%S^EpXLniYIo$DkbBKush3b5+k|I@L zOVHJ~;brXj*bbyurk~!!B;LPFndii^%e!;i=! zK}!Fn_enHhOh8RI1tc)aEh7VNhXtX{{asNzpJ%F=*b1*?d3%4cZO=t>^FZhbPXVWU z%3j{kZs+|njby!q%6mA`Pah2ic^(o?JR7>!6RS(JnFD&JW<77&vcZT5mBB$KXs<$i@|+!m_1M%|QYAqU1-%QMqFXC#sfSJBL(?A5pq_`s**hVwbG zUBs^8Lfr;E&Sk9^RgJlun3&yO7!O%4mRoP-gjVoFl$2jKKPHk!7V*=8mM`sjgKfJ0 zWUsr(Z`!Jc?jMPS1I&BXJ*g)q`Hwk0I;Rlk47TyQf<)c4cUMLFzHL-E$%v4VpmSmF zMyw$wF0aqOevb-D9tRSh30}I5yv|KK=P8tno5jCsm`heGrs5M2r4(q!;!>3KSyuj7 zE@o}eT(P^dC)yVE!1%Mgf2GI5!Kh;)wm-2GOefmdsQ)8%-)+8-1i-YfW}2qkkh=sW zZCT6QVPW*O&tx-{->%N-^jk=p zPGr+Z_1hf7?7H_dzZc_Bi25Fo$Dc{{wf^Pb>er93J60qyAbDlg?;Z76fBY_eXz~af z?OVt8^+8q85U2T*a9SKbI_0R8*$;uY+ImL()mlEFdwK%lxkj_?h`7)#*-6={BC+PJ zzb4=Nk9-JQtY1)b>?iN0mpXB3o-?6gN%_e1c!zPavnx1Vx0o}Ibs)d{(#g@`%jA^IM1yLKTZxn^zc$9lF{;LYET{NR!`0k1*F zlbmUzKvw9N*^o`6$CKKZZj~|Z?a$Sf-5MQ6ap>BStbTRxT%%a*AT-153S8Hp!g}(R zjOUKH#dg!?TRdO4W~8aDFLLFO{}NG&x~kEuuuth3X{LsU*Izlczh3^$#_MTgsko7` zND{5!%ga}x_u(;{`w&u;LBXoz^VI*?K|LB=|8miajkEmE?pP_crBlVxqAzESZF+L( z7jaF;hM+ZPm2k2?RE8)A$aE~}){*b!fk4hdC>>qZ67ImZ_0KFnkIwZ4(cC~U$986W zd$ai2Yr)mw3zGoR@TjD25OpcS22VUIQ*q1)9@Yza9{n$g%i~!Me=3YEeB3mzfjm!Xtjgf#+UUVDhEF8`AkY{{{{QeEpc!ZvS8XiEF5hfTalE`o2{L+$nQZ zI2!BP`=l-LyQ$_>*0soC*n}&KUin}FyVL0J8pWEgo9RBB!WQlJc5?FlP38rSb6)*> zeWt1r^PRuwIvg9Xj(%%spL$$>$x5jsKc%snhcF$mtX$FEZPZN9XouNrj`_4R((T1s z?Ca)s*lng-UNcq2-zx}re@*Ou*`|t7wd8TDg5N1@epp25V!OtqvZY3Bbb9vVGVQnI 
zi~N)1--)GeL7p)*_=F7TjsD|eloRWLAcnZl9&IDxxvL-7i`ZKlcC)mF+}pCA>o{nM zvr(@_2Pr@G*%C87+;AFpeR@*%G3_m|TNue&x2xet-9PvE}XOsf#Ibyn(PkE?jIb0ldt*2kf7P`h0tboJccL~Moi zkF=zUPMfFjdRF#h=cVnQSI9?ayFQTApcITdWnOD<5xdp=!|Yk0{2bH%w}?NoGwz0S zCcD497xh;JdrfjDroma;nD(#0pX_tM1$^4B(cgGZ_MN5U}2)qgY+wD7O&c3L|A9;@Jl@m2JO}!&8NV}F|2}FeaQVfsMsA2 zxyX)PsQB-OeCOFk&B8VohXjyi+hndfaCnvKE?GRnw+TmZ)U86|0BgX=b0`r|larbTPRdhl`IcE%$m zNc8SZ*WmVk)*;bohQyj6q|24zpyfC~)=zmKH$T4Fcg-$tqGA_Cnd>Eg|gN zf5CjJPXYyXi#Uxf&fIl^m%I5NwBcI%c~V5qY$H9}Mk8{%G`v1x^P}4iLLoZW_N}RZ zU{qAeN}JOCT%`;SsI###I$Pn#XilH(MvAhpj{N4CITj_2_dg#jfpEG^1cGDc>4o+U z1@BB@SIS=Fn>4VZ@w&6IPm4d~v}U8p`84*bRZY|EzGnNK1$23Cmf>(`)SUQIRs}$L zkjwBU8jT_ki5Pdn8h!+&`Msm*f zZcuwe#4#V`>9Bq6R~PD;#UQr;1%v0wzqZ}=i}~r69V2GXwT4+H<5v#Wi2@;N!qbz< z=;2|-jAUw5mHUk;-sLv+mXZ@6WxnORc!yHDv09^K#8jN~;G687U5a&~<_ITyWR`Ww z=qrv?4bicO>xLA~Ln?XM=CU~-r5P8>@K#4*8-Z6>Jvgd6O~IrW-4NbNq}VYSN#yr~rPDkMcHkt%e%COQprl(0yLb2hb)$ z!w;kF4-&zhg7xiBxQxra4|k5_Qr|9m6|JCg`Un%G3fD5yNr`Dt4auCuz- zV>HtH{*vdCne4#9gw5Y2a(i55KXaWl70$Zt&Vd_-IVb#aU*8n(SbuHCK*S*})`m-E zRhR0sd79iQ_17yAC4&iGGdU7526s5d%KKkCR=C;|dn#)X+pMeq zc%zKJ*r!BR-nZ>oQFivzwywn4IlaY`!y|?Uqs2Q>}$;REZVqdtRL5Y{+%jhuQRDJw zF^g9Y7^~rT6du9y2dZD-m76uO(ZbD?#^D6mGR*@no_tk)(cik)WL=fC*FRddfALl5 z_|$ktSJ}O41Cr2w+FJkTZ<9h(FXx|4jXX)aV-P)OOoF5YTZBR_f31!uO~9ThVO@8( zQf3+v3-zEA>-HxU-1UC(NmQ(RB`rc6;+W_D2Hte=Y4 zoC~##4jdp`yQ+5^Q#^_CujQa|v$~XUgf!scqY2i$fj3=2Y~53;Em*!6I%VOIS4NOo zd9ST47beUa;Ov1s__X~Rjm)wGj3-Z$x?StD>lUS;xNEOe(_A>ye%E!^Zp00L+trNj z<>QRz@ym8oX?2q`IBxtBbo-SNW%`z{4W-O3>o2AMSv|*+J7o*nO4FVF8FbQ%)hYe# z{6;fn37{HP8309qWEh{5sM@Cy^R7CUxG7U$V#R71f|Y8g^$Oy1#Wd@!Tc_MpfN0+e z&OF-%b&mCTDFs9d+2>JcMAh_SKv4Pt1)@~oA4eC~pp(tjGqLs+kMLK$N?y0-SjTdU z6HbpRDr^vf2Gh>Y3>mD>t&l|~5vRs`klo?UYt!>sZlQ6~J;^OoZ>)X8`v*b;p=bS{ z9eoy=%dJjWEp`FtPAoS|an$%#UNJ`k2M>E(uN)$0SF1-uULMBT?)oPY5n@41Z+z@Bi}A;s#hzC{JRIQ1w+cudZ+G#`;TUaL0wn% z5Ls7q00*!D6mhgk(~Wq&?PBa&o63p_XSb0l&EFQ<3Pg(av%*AI z2I*^us~ucc2}B$k!a#Q3O}-BrGp?*sF)v~^%bdKFN{xqUKFPuS6ZeZ|c3t_Bya-M! 
z&Yg(r{>qjgH;wBSY&cjz%AlZy1+C&`v0nodPYw*YT!TBNbi(`nW-J>7PaV9~E(^)-coLZZQsAGX$uWraXQA2XWhPOw<@-V=J z5=Y@{V8;To$>i5uqugvGHt99j9&2-|cyAq=;oZ3rjzOnJ=jRO?kJMi|LSvw5@B>## zX2D|pp)bd+PU(O*V6XvL21A?GcpmNpesPUjp2DtGnmKaFA1mOAde)kx9R9u8w&hBK zWli03y7I%V@nZgsem~^SNn9kqM72fibqyUBg+RON)n*#f!_!=ub1&)bY=XG8S72mJ zVB&07Te<;yhxfHzRK*QbDd6@%=g+IDK6@VhfxEBz)JZiecAOYz&)4D^4PtnrVg#mq zby7-Y-r%Td^2g2g|Asi`ocNcH9gojf41&BK-ZnH|q7F4|lhQFu_C0+uFt(`4*ez{p zPF8YC29sDc5^&hKDqd3KldzOP5 zU*C5A+Pv5b36a7;S^3i|1bNP`v`u4dnykEWEKhMcxB4$5JtEt1aReybU3dUtds(*M zTOBrQrxFh%VQ`d_RPn3*ZC{mSpXi(~`a))$KO)*h4@2?}a%W|=4@pgOVB1GED%?dD zR#~Gr)NtrA^P1j%jCryU8EL@94CmdJOJ3q$4|9lhlv+oYj7X9Gbp|aBxlXz%-g!QH z7ggoW#!~4P^noSETqSf=;G8^LQpeb6(snLrC%1^HH8O}7RTI%DMFZ>Ew|$XRjF1}F z?T}}aMY5fb&Ci~9cl%MsF?4d;JCK%BP2g2S z#N-_xuP~nWm%cNnlhMX$PAbP$mEN^`&9BW z!C=)i-~QG^^va2A%w~eN$W!PiBDsE2M?>LUN+wPkyU4v%K2_0|{p@-Pk3(|Fzeo+) zbFSP{;$M5A9kNc)>$Qf92*9=L@HVW?d zu?FNVH(Ztq`>2`twe0;{UoqBa)`shD*p=6IX1UZxiiVvP(n@3hQ`o39U6m0V-ZTH2&h4kh z{=q(c6)E7Ap6x*N*v!mA?m+FCm@8M*cV?0cdbH^Y!N(cQ2n7pnig4}k3d2Pl>sa?v zo36&av_=vlQz`SMX{2RM+X=TU0?2uv_aefq-jo$gb0cva=cRXN%ab83w3HG{7z9Qo zolqPQj@Du^QgX6KW!DWZ+E+c4gtnm78<yF5IKf zA%-}D<{%sDkn-956B)cooI>9BU;pY7g~p%X#Bzuz^1ZSXUQM@!)2}aof#Xo8O77mUmHP1${}c z_kMUBRrxI+jJ7vW4cgZCgx_=YG)!~3amJ*Z)m%?~JYQkD_Ji1(nZ~wuB~%9YA=*e} zJO{_QMx>0+=-F9v7JBqScsH}*p8E^9Is0+Rmk9?{f~c7|zutWmO=K#pvS?gY%H$T_VIM60+-gmh-khN{`-@OMi$KKtXDseIHo)1Qk)cjfF{Wchm6xp*h8w@;iY`U?7k~gV+>id!abAK3An@NYnZP0MEs#f!(LJ zLtg8tV33ax&c0sPfUes&P@70=r!e?WQgd^XLZDH-H6LJN-g^O&9YV`_h|%T89X`sM zpyK!m?TRUH9}UTD0QYZI5{MHq@k5@ zTz1CS99xz<`Iy>uY)au+skv(?yr@z5Jv1A|Xs)rRtOy9~{O{r7F|iXwIlpnT{xp8y zj2$z~`;8>Ui+}(JY;6B~f&TxKq{9DCTVGK$4$rqo)pA6mfWMd+HU+&k8K(-}s(-5Hva+&7*0%`3XPlgzfRyL_>Cx|k7w{0<*pvVwtU}C}CAc1i^MIN+ zGb@W&pX#4z`k#NC6&E2`|EGDah~L%h{Cr-$TqIFkVq*8`s2bqOhDymm2yk(8q0k@8 zv5AS15fNWMk|6w5v#B&{a7Jv+EiAlzxOD^gEp>86MIOMlCnhGwOh8IbE@C$$vHV}R8(Ky_ji879JkL0eA@_9`@l#6nhU_A z38-;^Ap_X8pR9ntx~{HH$m>Yj2^j^YQw{jCfyhEzKyz!Wi0}G8Q#}_ULQP&Wh>Mf+ 
z?BwK=sVM|dsXS%@44FGU6O$4kxoWMl93ntPMg7-LT(AO9&F8!k_q{0y2L}h>8$$?2 z`TqU;#?}^~(P9x4OrNks1cm5pL`6iDHZ{$UjlCwX0vJUE1d>X#o*#fSwer`m>8UAp zE-t_;7_1MNB&({b9&b*&y1O@ymYV-tA>hd-6Y|7rJ-yu5F*Y{-*Dxknp%$?H^Kx_l z`x=70`_?cq@V$TwS>S;O2O|Jpz>=d(0o;33`IpwFfe_rx0|VHSGhJZPi?u!Qb_fg5mvYuw)5 z{5o{ug1y>_ikjtl7SDY6_WdV$yfS%qNRS=uXP{u($Lk|{Hgi@*d?0c9Zz(*DwOZ9B z(_ek(M)Ay8H}XmY0-i5Bg-nVJL5@KLLjE;LEoW&sDi5R&Zsm5^xgb+YxOmDxF4DpI z_-bT!z;in(3X8<7pHPk^&Y`1W+|D242oOmd$$vg1t1_w}c{W%}Kd$|eVO9v_p8vVc z8t5rcQhf@}D^9}hyg4m*ho=HSyJo+qGl^6m+4U5t{~vJW_?Q`*!rA4C_7msz8`cVL zFBKMs`zu(Btj%@CPQI&h0Ncv}U@84iS6G}W4c?OB8+(_pN+kS94bn53;3VsVxN$8l zB-R_1tbQ4mCNO-COsWec^fFUa^2cQqvm)#kN+&W@S?9bO_A=tMXCK6bpfn)^wpyES z70{awj?8KwEozSS17^Kauc)q+W@`~{vgIR%kOnDDt4w*>UPr8cmuZVfkD#oW@XBn1 ztRlDaXOwr$a|=gF5fMt#(jn#L<S{M-<>8L6sPt;t%DZ>SiT^&G zu@iZ-+KtuN#nC<9n>|;q#6K_c#lI$lsqUOM+ikb?fu1R-BB;B4nW4}sgDr@2_Ym>) zv+f}qJ$*`Q>i1=KcJ>C(LlQzlonnww*KmLT2YPzc*6oMe^JcH(BKQBNx3_GotJ~H! zAq2M&G-wFHg1bXP(BSUw6L*IMcXxN2xVvj`nYa^NgZn(=U29eC(tYZ~x#|ZD@4b)K zp4RTCGralzPvzK2ci}tuM}u= zG&MC9`}l6kZ3S3}A8(FRQxJ@7Fy+aCA^TPQzu~GkEdPtECOnXAXl~|q`kIuONJ&p` zWp3_vG)H?aD=t1{z0ydIe#A09w5MEoO@lbimrtq2YbgpVGwU8)`E z>G?I|j7WnCOqD5EmYsy!`WsIX{n2pgV4m_*vXL%!vzap_r+52F=N5AYiq^Gh);--a zWzLs>cYt#KF2H{4-4PHFtaka5g_CTGz}Lh4>qO)uVg}1KZxP(noA1O6md9nu3tgS4 zE3d!?3U#9>OY1(06+8%t4vMF8WP)Duii(e~FMJ+aGBOOz%>P<=_Fz> zh`DSsA|lWVM}=Wk0A~mYMEO$VSpazFf`Wp}%l5CpQmmozVa@BUw6rt;G626zZ+ClO zvH)V~a%h`wp|<=Z39gC(UqVAw!B(m^7ti{jSe|mAc4sJ|Hj{AbejUuDYNuXRGdu zgF(;9s*hF4dJb(8y+_v9Bc%+%A7g8W*#^aizK0QpEdEV@Q z^tKo1#Rl5!@EJ7O-L&A9F@ZJ|0C(8k{Uc`YK)Tc@6O)}iGjPa-&YjVgy|LTos&l@k z)d0dR44;i(Ue>grE!{df*>^d|!^SIevVW8@GRl{e7=8`MuK@UyC}Q%b8;%;;z4R#l z2P%_xRjHCgM=Nu*Z$_Eb8%5#MA`>QO@X=^Iz5=*p7Ci1zYnw8F!v!1A8^r7Vc)cr? 
zTfn3)59X#&fy7Y6^FuE5yz$MFH+CxV281Xzl!v9hpO(9n50 zJ+S_Ri2oQJbx%lG@A9nz=rNVq&XEQrPmv?|!85P*F#&T%VV9gE?}PJm1BrzZx3Ew7SleJo24U}@)lO}~9`UbWBuPbil-~JwJjbW2^`fHD= z=GfAh4L>)KrBPz=5@b~KJ!c5vuPh=}IA9Ob)RAVBj9=@$sHAkN-}g>@{8eu)TDj1J zZ<%M~E?k0c;AYOqASGB%Im@=$XNa?pHh)v%pbHQffyOBs48;on$N!eFCc500Ix?JK z(s6jK^UD=>M(N{iN*8!eXAaZ1`@H>0IYOz(-KvH1f7&>QJ%% zIS+qh4!Al2i?PCLP61(VEi(hq-tk8K^UccOf1%V(CwZ8q^Ld@@bPNnt#C*aM+L}%{ zfk{4d{koi*wbYwzY32l+t7A; zXBU+7+m}=7c&{q-d?+{#7Ad7G`P)IfEmurutuhUG=seLbAFSH(^dgz{d2nxq(D`M5 zq!$|X*S?+PJ0lN;SCKaQQ?FXVCckOHqZ~np)Z)EzEaSm*rM~2jV2OtFVE)?b^dE3L z=vw(Z2*v2V()Nyj)tYCP_SF?dpPhtd6>B;*HHU(5Zj;VY+GJ@B*)~}V_gX~N$NJ8+ zSl0&)%sWvdg0>aI%iurv%=@05_K{!R(@H$qut1j#*!XbsrhF|DWOU*J(-4ZE7JfM{pi>+Sz*ZS3lbE0f(1B zuU#o0Os-22Vs?JToQ4&L8Z@ot$Z^wTl)ZKL{YJm+#w*_`#Zxk{+Dwe^+c@=2&=>bh zZ|_mxMKj_@J5PvPEeIufSHE+9TjfK|uCIDjDYcI6Su(-$3$QkkIgY%L@20REXJoBs z$PC^PKu?t^Oj|T6L;uWZbu0;&zUXRp+-31Rb|xalU(Ll_xGBkkPd`*kYR*rg{W;a@5?)|$t%##&w$D(aN4R`A0 zLiq{*xFLUdSOQ=j0P^+!g{F?4f6rP;bHMYNJcX*@rf%r6jKtj;k^=y5;?}dQ9fGT$ z9=3`k|2WkHxCM}%nJYo-9LfG~gtK?f^m!&Cro0vDbuX zz3psLIkwnHnp6HhG<1#$F3Aje`n8pIp75tBwmqK+2L~U2VR5mfq-61~>EmaTbPG+# z!&CP{7hI7sp5jk6=M2A8c7;6_WpnqKIQ}o7b_eUjfG|*%XFA!%Q_W5&BEn<-+yN{p%5t(RnAG)9yv6}B(11wWNLbLa8OcK2E=}c zPY-XkDg^rtw?fUVB@j-*lH{PB&Ma0*jei1VKw7A-wi}6~?6dM4`*$V zBs;!i4ZM}lhxd3zS00t?@-|{(R}>uJP)Sm4cVwA%Wz=UUOo67L14E`3?;gs`HL_2o`!=0>=N0^O z$2E5D)a>Y$iMQrmR-Z(;u2NsT)CFX%zNtlh>eCrQ9{CSODLjLdWRj)Dj^}!k5wof>>F%{vwPgl=GI^;*;gUk$_OroSH-tvgNp5d# z)c|gOv`R(Y&h_V`ozxJ}G!y-vaZ4HDo&FC{419ULDxznKnmuL=_?GB`9uM?>?R z-?1&U9i-W5IosZ{lpcIV9^#{M}c({2*_v41z#4#rcDuT?fA)3a7>k*p9iOP0LY0`tI z3=9mJnVG4nsrmW&Y89F{OfE~;U&CkF+Y?pKOfB+u+oyXYm|AIQ-T331fzcZ0a@DQ+ z^xo?0v$E3I4oQMRgM9>D=|3ql1x0*z_68SNwiB`m_4V}}b9zejBkMrtbP04UZK{8g z;V#Aj0}2vc=ax614_+V6Zj}QpQAmnB7;qmkuXLdsm=4hn^|mvA6K5i( zaD3bv@Q3rQ<2nKN`OPeZwe4y4C$`*rVv=6(y1OA47g=>ZRoo(BGg|Ah$NCFFhg69c zns4{Jps)V6E>w>&!H7|uUTw=nj=)kSOP}OAks;ZndFSo$`$MV#tY|$zvX`_Ay@^w}^I?LLfZ_52XD$5YCS|65? 
z@pGN);*E<}$Kd#)8`Nmgv`Emm$0RFZQD!G`zA81s+n_^wBG+%RCmX5qHGaX%sY%u8o0f2z6;GmoMT~+@$w>L>z(G$)he#8D`AAg?_pMD7kzRgYVUKzP^-_*0z?pxCOFY7zUP!!n3u%#pA@5REQR{N zllTJC1Z{tz3g~m$ZYwHx1^cw;Kh=nOdOrRB{d**8=%721Q&6z8ql2VV#?{FwK#qoG zqqmcvp2o3)AUG$@3`=c!jUmX#*(othYU!>~cEoaE1kv{nGB>Dc-ON2p*&|zSD|eQH z==&;E&iNOZx|?(wfVp918lbCBPfr7i|7X_9+w{h<_@WiYKlz?VZz0;d;Y^j5^aptv z_6^t9PJprmc;CXp!rnVo*VM?2H52lVC;4jWQ~FuAX`SdChH8byyd8p|prFiGXsPMw zj1LUJn;i{=<72;0VHtJGuhk(mxarM!iW?TuM7e9DCtK4|`A`uJD^u9+yRv~ zEbKR+5Z*sLY=rR7`v7%2CBmBH*(P>8cie}!wzCkiivTqJ|Xs# zHxW;1#%9%B-89^GPd|q|miuIyoq~v;dm1>{>s>J%#LbpNAzGc8za5px*g>HMd1i(t z1l#p-*^oh&O5;I*b_aYgwp39wRQ&z@i)D^G^{9EkJ-!uaGn}=NWx=gr__%MNu7&uD zkSomu{f4{qjVkSy+H)OvL%+SRoC5qgh6F1T19itLi$1Z)+C;%~<+p(01yF(LLw;P) z8949SLYcr8HZy9}^emb75`c=~Y{br;jK2TFlHOWUf+Xk)SOEhD zzCesSHYW0R5C0w$7H?TxZ*rrAX*`HgI&3t(YvZYrmaJ@S12p?fKI24H2kWt_DlblJ zesR5{q|PlTw7L{AX@n1h^!Zn3bD7<_&x|0a<3|g%tkEAr6G9h`mvzf-1Bqki78Y3V z-oxCcGfVePT%J154{L>I>`v{tdz>8i(OH`a5&VDEeFq?6 z0CmpK%kwz@@($aO!wcf?e6*jHcZe%1D!9b&`ssWeh?CyE?dkEoJxK+qVnf62oCk~p zyB&Dm&W)HWS3$;>eC9`-ff~)o>T9^HZ?kQzs{s~sr_^Nhc&MZFKwU;v6|WOBzZmMj z{-vaWo`JTb0Jev!R`#3Qyj=25`VZYglLSurI)?ZJ+%1AH9>A0b!^^c3fi4iEE1`-D zK|}>n6Ha{Y;RxwmvR}T}Lk>)9C$j|AwT#Uh?8E6a41gy%JOo^m-#7c+R-|DYad`Cw zzZtWFb?H5iGUXtUuH46Hp*xHO7I)MZ1zfUbxwr-hV9D(BcYA=xaj)CZ$i(F3=HXro zJ=s~wvQUa8OGSWtuWY`V>17c!u6B7moeLNrP(X$@F*5^9Z_pn-&&|JU^Won0B(z(C z#XHuA%sI4QKZN4Cq;jWmsG-JVGJX#B`~%3Q*#e#}Kz+`x@MWj)BQ)V`wq|M@S-{`| zv3Ey7IUoKMYNUO7ejnc)M+&TI7uK|dk)-^N7N5!yARbyr@KVA32iS@aa$wPpYJ%69 zAkcSfG%iuESB`Oi7l=h|<1$c(1OrAD-@iB1*FX9Qo~c?&EmgM8?$Ch_U7)VP0aW+% z5e4N4T+=}PZ-0Bc7O+?)=lc026X)ShKvaognJD}CGnb|zaXO^>1NDHa-Y^h+WRH8e z`(@K#B2iPD0RpEvzaA1ip=1mn9Uvf_zptdgez_%30R=Eb#FIZ1I1?}vJHvE5`oeAv1e@?zPT$Q z2Oyge0&HU5n5hsFrgdgV*Fkd1g7MxIXE(7*w28G?o24Jb8pPG6H8#2c!Go;qM8t=_ zdiC5iPeQMie61KZE|1-yl!*70EE*Mu(|F&k&nMqyj@_&lz!}$8hX9f=5Eif2lU!6r zhvpPpSf-C|)F~n`jwlXlJ&wSwDUim&sS+g$h0#g+^=$z;B9bM|^%iD%(0D7o`5%qd zmoLZD6+u;@8wM_$Jqf2(eYa4LUviV4;&3rR>+r>@QC`76t59Y+_^}ig1RMiaD*y{- 
zZAd9X!ws|$QtfP;Rp_l%P*!_jhS9s1Imgm0zuGW7%=T2@-#@uUNed?TU<}an|NVU? zsai^Y>4rk2&k`3pui(;uHyPOjN|Dx*EUil=k>Jb^e(2;iO0{Q6y%wyL<)9@|mTw4u zf482H7)AEJ;(*LrL6%vx>uVm7vI*t%C!wVb$JcDYt{gZF4X*kZ4;9&MdOQ6_Vg-_*ByL2GH- zV%+PePySm)&H;|Z=zMMrD^c-sTu|=${+Fq;{KHm5u80ZV}oXW<+k^>4M{oHiY|x1w;bOfoW5V5G%uPoeuzeR z$WC*-q-^OrjQignkwvSbWB9Op{)HC3qr;Cz%tuG&0oRJ{z(+7(Ik2y>iPNX@i{~BV zjX zM}09Dz0SfXWQ3>VA~F)4pHX@Y+LzL(pAd8_;UEwHfa&*Kj-b}O8t$M-t&ccO&x-7J zRw{6^24k=4Pxu~VxQuVt^n_mWl1ZPeV$^e$NqTk1&q8+CE-f$!AHQR$im; zP*146gqq1&zLY%$llzL@fDzW&?sGt`5?W0-cYW!X)OuQXrc+TX_k*7nC4)2Tx?Qj0 z%#fpc8Oq{21;@klOcUuoOG(`)P}KV9Yr>Jzn8%zh<$BijdQBoc#cWrpFvDkD1H4A# z37b!HU}rfZ_j=FMqP9Sx)|lDm8xsFjS0r>&L!xm|O59gEZ`A0M^LKGUmhozF(OCEg zp2T`NWmDK{XaH%Q6<~HUYRC9Wnn+5lEKQq?FPjo zS1%W^#h&+OxeYIZnD38Lb{@5C+RKNW$z$f|N4ZF9_*!L5Xs;&4m4j*t=QLlhUpbyo zvn=$gFPh{d=v>y(9j0_71}(W^22rY)jo*}W9>RcQIa4sPT=1TV;{|Wp^AJMwjPNz8 zDhg}GMD#x}4NAG6;x2+71wS-~OKo&b%E)=)5XS2+g`F=Y1<5!pa5SL?J>ziL5;F7t z=l0;r#SrTMAZXXL9S|mNEV`B{PhMqz-(~s?iJ&}gB5e>bpqdbVv)27H6hYmZ6;?If z=_h0uzNCfwLBU+aSlVhPJ({68*O=4$%)xyZ9hAmPw7_7JTORxq#s~&DKtZ$12bqq} zL<`v2zyXMb<8_h@9(=2pfxNuGzn`j&9vT%@1+M{5RtC20obxJ$u8U6o9-IJ@gPsv; z`fgYcB3R%?!?G$F8%t?>2TW>FSsD8a3rfxwr>7Q&`$bx-UKTq%U0+w4CSRVZ+CJ_G z+bG1w1Oa^zuw-N##(-UY%EzA&s(2tfh4X$f!z7gm3z2rwcx5>t6T%rABT(1K?jMqK z(GdGYODn1HjeKNCTvphycJwm(*T-SNQt3G?%&)h11nBEvwub>miQPI23nvB!#(GJ~ zu;g|4iexqh8WlfF4YpraVS$~aqnwMqm_c}|2H;J+6@=c7iZ2UX?8D;3gt(+6{O)dz zy*uu2W>|IC4 z@zD{@9*(?{Qg}$1CK*3aCud$VfrX*M(flAtK&u4M)x`q?2KLB+ z`RM?o`p-rhFuO42Koc1Z%x_rre;{v|1i0#dcGQ5OLHu8M0v!0WQKjNdGO?txI{C#% zzq#gNajfa=!ZQS|E5J#%aNexRW^P!^R&$VfX#i20W@_Pp!sE#=EdbM9sQP^_k%^^~ zvzXyrzBpF&-=OFY8TG*m<&@>;xqG1hWr!)IJgd!ImNIw|Mgm2jBMepz*R1IZ6uE9q zP4Oo`f2i_)aBt4ph1^_q(X_J>JBM^DeG!u!v2pdJzIoqxSiqzf6OP${-P1;Y1#fm* zw8ITLdxv3Z9=uk4Gay6v#D_KcBh6k1Y1!GeEbAnL17r6WVH}B7MlW`}u8LPMu&6jc zTl%9@Nz#Z6A$;|K0V^PooUcVX+psK(w;eMaSSM{sF-kJjNuGc11u@HeZ?m8u%i$Sx z7KN9F-V4CM|AcNDwfo4ahfi(gEY{Vsi7mBK-29AbH$t{~E*Hz;=?@X#=&H0-tl(_? 
z_XDE9lwqt<*(zAY#W)#0oMRBxci|oH*|BpdsD5{{oYD}T$eiSR2>QeZs>MWRHu#G4 zR-u@vteJes@q))?wIXPgm$LOz8^Z=TqoUPC;%o9wV3d^}${ADwTc4vAQSGGm+wWG> zp)0Qafu~!tyq5+uYYa>3Wza>|Vg>WO+Ku9HnmK4yah0Cne zola@!ml3+1g(a30!@S%c7f_z)gf`yvV1*Trs$!LiB1)aY%aq zT!!3+1y8^(F6AD2JpGY3d$Lt%YRj?x?P!Up%p`R{8OdoE9FUMPd0=s+ zN>hn#M#k@qzLA2TEd{0-{T^B2j!p7ZnJS~H7T7h}E@frpe%4mLI^*$~uB*Tghl>Jo z0DmAgl*}cELXd!aZ;mA|W$$41!vK4nzvkwL;YpwLuO6>)_pCG>-IcJ%=u9nXQx39G z`S>?tj9i_i7Z#e~Ibr=)IB!5tDR{{VTG43~^{OfkJ(KUc$+~`c#X?)0za%I8_V}jy z$#JVI1blW>u3~yR`H;7?f+O z)u}4^#$Bzm3SuXdSjchOu>{8x7-%_tat1fZshqHohs*=d<4m>Cs{vQ`rt{&ushI3Y z@DbkU**Ccy4%=6UPqtsnA+IGcM&QpcMsOVM?N&L?UrO=h8jQvD4+I6jyp6ND{UnRW zoE%UOqe@|xqiM^ado@hP8dVsodWFDV>7#Y3@-t_{1^e&+ zo-A#s^X!a%Nc({Q3^WwcP`!ZL)jIfwL}faF^y!{)j(E0O(Kgb?z0r7s&(xWr#oyoT ztiGsyuKrRLLDyOfbHN)dPPT1X<%acv5DS;&hPwi%VGhh*Ci(QoFaMyiI|+bdOQ+v(PoI?EIz%Wa1tyFR`aJ|Gbh zK(INPJ>i}3FV)~wUJt&@Xw6pGz=r$$qKMiRutHb|QX(#Am<;~oDo1(}#lO-@SA+Ew zD4tgNTD2VrMrv3$CJ~zLo{^6ztYDtMadE=^>~*iwyCu5DL#URGkN*ds0O$Qxt>4I} z&S8yCTg~7jJzlMIp6ofwqIy}cZ))xiPMlXBudK=4?GxFm$27i2t{V@Ri6JvljG(OV zla5uDes!Kws?}y9iV4a4b!8Kbpx1n+TE>d!#B07EezIpZf5Es*ol=g6*S5AVwyQp2 zxJ{jR7wrz&J_v>toJnsYD&o*S5P6MsFqlqQ3_Oq#*@sv-HoNahfB#qlUm@JyNDuQu z_!+*M@B3oSPX&&LaA5>iR(+|rAV%^{EgrNLlPM)4MT zH0tI{K0)CGH8m-AQK?M3y8tg8N4HQ1Ms!M^p{I?e!y`YINV~Y(w-xrpHnP0NageV` zRXJAoDY2)xt%0lQDUAyxejQsSSYOr74a;nR%V7nfY~1Vl6ZQ1g=6Jnnh$&<~Q;7}r zK0;bq#FRs@&3Suf=L<2UcP;uVW8w8jbn|C4@t{SXel?`@aW69o+jml3PcNj14vlmxEmVC0{} zOV4hY6*XINULXDqT$}$^$Q^XkV)!q9O8YeF0Ra}Kg@u>@9iQk=_reEt|%UU|39Bto6jeNXN`gw+w8MC~1m;NNO|PcFe*mnZ7=Rc2I3 z9F&H}J8aC9(#Gqu6ZgljUcm=ggsw|M&XXkiMg z%Rf3Kh|u8}lJlpBJ(Mgw0yk#{x#%=&4Pv;?`*nUf`^Y1QB0%e^pj%s0;Xs z-b}2B^EH=<(Q*-U{iiDRjsZGx@tHc2CxEg2cF~&u=&tN9jo6`MXbtD501VEO%cM6O zuYWLFb-5y)JfU}gdz1ANxA-1yUOJ9S7^J(RE-H)G&7RE6lAvRQ)|fZ@)+ivco6%Ky zoq(3+vlk5OE^S|b!2uDW#?3XmCX|gy5DCkk_EDCe!b{;BYUns|t50M}r>Z6D)D|RJ`5fjwm0z=ETXmT26(2!#Le-Jn0Ho%eYO-G7~k z{iC1@AUYp z+cbkzm)a%tie=KNzVAaKGEgJX`dm445FM2NQm-D_H$At#@2E?<|HLC4Fj|-Vu4c4f 
zX3g&0axRIKr%}viYxc78!4}EFND&ihm@jF&1pi#>w`)Qct7401cV~?Tx%;6&zcIn{ zfZXxIDXg-$l)+Y~F})15_an0$Jsx=dCE|}@N}Q|pjf4@4FBy~eJZ)^^{^j9(hE!CY z#_KFt4~{A5f^3aWVy#Ogq%)N$*;7n-j4^!abNiT9hwWOoZ#(0VBTa*24g{~va~>`3 zM>9)sl&x>+W-Zgm@653bi@j_YSBjn_67nD`4_Jo!GE%+#Cf3)ouSRu*HSixhBxiV)@Iq#Er1_v+CQUw$ zQS)4N5#)&e5)!6hH%l8V<+;lpG)3dp9iF)MXkw#17`h8)`a`j6lY~xCM&f|+N4?<9 z`J1({B9jzZ8lKzZm3L2>dp761=8rGSZ?&vbS~9=)Bfyr6kpe@MiKJ!)uWSwc4I>^P zDhRJf>`&w1CaY-%WxjE|F28t($gFiKznXsV33~0@`?|K}>+J^AeMH>GA@8Hs{th(liVt7i|0Wn}W9t0&A}BJ$*rsnQ!mYmjetwcV)FNrA zzy!UP@ET$wu)i*X=H~-6J{gG3LgzrIcO2@{!PT0RJ$S-(P5fmStIRV2v0YJWcRkAR z)`?u5dO@ecA4>_~ucLI@n)U3zLC?`9R^=P&wm*T3)p@PET74KUd<`42g|ss9Ha)ob z5w~O0GYvc`T5#3TFPMPPS8r1lw<+$pxKiUqjM9gt7cidPlp%inH7=n#dm#tc73yv} zP_=oT8@@!-+@ofBuSeDcY95RK1fJO6KHyxhLSNzc4xnG-`=ZDfsUK_5^RHqzP0!8K zJ02ZG5IPkxmOV0RvRRX2|Emr)L_55hRc+vx?$=y% zl_ha~bM=Novwhan^JDri2n_T+Fxq$bhv(5cx|uL$e%2S8jfa36p@ggL=sA!Wn|GO|Lw8l-Rc6dbkn4H<`hcCAh(CgL9sj zAd>yU*GI1S4w+e>qnjJ1Tx^hiZKJuHJ}*&hz4Vg-(+acW$KYDc_+rALMss>IsD@rg zJ>o(2hrq_^VcrC)^K_x-=xBWo^|Y&oGcIPrx5{r1oP{{~t&(0J5JBtyleLwS7{9OO zRfY#>Ci>b$#XTny7Xt#l-6iDrJXkOWKsd^0V1dmwJ{esQ@!mAab=lc&>|ULGff+~M zj;PKPYlE!1*iSoM&C8cl7g9S9jbQOwLP7eQ*~3g67G=zW_})D45rVfnV&}EhKIiEV zQeXo;GW^8Ogau0XhJ&Ttz;q-QDhll-y>li}Wm|s~X zblhUyq^q+H`7Uh_Ps+K8eZQN|nZ$?4sxK~ZezIiTcm1>!@g3#N)>!14ZcOmCp-GVS z_+c0Rddvc-ePNL>X(v@gHI4P9dn0UU##N|RIx>odWsDrT8>6#aTg9Ht8gB2`@?o*z zWY%uq+P&CRyr2~2jTR5JAz{)RZ8tJ0`fJ^0wC#jHMqG0|ncG=~iYWbkDN79y`{5>)Oue8fvo1VnJS`0f4nflqOFE`{=qv*n^cLRc5j6>XZ$s`I|E$kgf z_-K%bW9Z6Tu#A#3l-PRW`3X}7@Ezvpm^#*;$*0G|S`SmQlNG*K7Td=$xl3ELmVR-G zcXYY=veE=dkM*eR6HnY)zwNG{L6!>rxT7!R+D(eu5VOcmr|&D?=&E(YQLxCz`&QDL z@X2y=NTCE}$efAH*G~c2jQbai1{>QO7Htd$O)<4v@z)J=_i!DfHPP~CHu^Hi6)l?V z!ypD-P+gOlk$F<~gqsnR@6wJ>;p?6UGW!aBUayq32>8Xh-A<#p|E5dF$5s6A-J<$r zPKH`=kequ83)_~ltwzqFPVk>aJe4!q6x#qMSIsU%y_q}4FPsf4a{df%$cZp|r<@^w<5H?uU`vh`x*v9q zl9uwUr8n)aaj2|3xjo`_QM?m8~etreo3{Pwu7rb!{~UVpyum3=g- zYmAra#|7$Kf`_Wou@z_s&0=M^IzCPyTzU&k!H~%Lppfbz&SRY+E-KPu;EQ;~z=>jW 
zl2&x0iq+;lf8Kljd07?*{Ozm#&vho+hPT$5^|aNwmdlTPWQGiT5iv+mL0th`NBf+{ zsL-yD-_=FHJEtQMMXrk{vSOX2GZm88K4+fSzEXm>nSPGb@oPdWKNjy~2`w@CO9v=Q z%tEaj>!;Lvmj73oF)?GcV@Wl;nTrs0q*s5;&C= zffr%2a*(z)U2}TTN-_*V=uKj$o7^uV*(WjKkk3Q?Dp2%;RoW`YLr+I{wtBGguy;>Q zmUg=RCOp&SPmcw%$J?JyfD+-joUIl9h12+mB-8rO{E3;4F=XmuYwi-)}Egkr<@)e0=5#_#3FGDGgYJOgZWDEdJ( zvz=GGN^iNYwdq$N6vy!}Vze(L;5q0F{(Wry$#zFeB|M~d1Lx?QK(N^>at?w?7L%R4)lvXp|@wkNYYwh3ci8=hPh z&{ls8wqax~GYJL1Ys{-29XdVh$8mND`8`6Pshj>#W|*`b8(u7v#Fi_@#MSv_t5Hfn zCHwrU>@4+a&awZR(|T^jT^*OmYo3ys{o|2$lT!>kaa@oXZ~E;CRzv=oj7Q^4S@p?^+g9NQ^ut8e8y6$85dNdeXBkdp^4U@q~;} z?BbPNIU_%#O+Od1?y+4$k1t1B*3u##XOYX7q>JjA7hbO)&~K!(Gdk5@npZ|>UE+#= z!_QmKu`#pG#cY05Zi@19bzJlRHBceIXNbqn?YLm^go79gVO-=m(j+_r5#fmJP3OP6 zkM4Q8c4;x?+-!@XX>DQMGb5QGV4C3lTy~-!@J+s^oPX0%?_(GTmUX~y2@%*Nxe+_c#yH}IGc;0i&I7JwWr%FiT#xuSK5% z1rTtaeT}SXQ?rRkxH#{EQO-u6YA_ivQPCr0cB*T>Pre!_UB|HwZ3|R>(_Go%VKuw~j5@!^Cz!EP-z>=AGFH_$W+ny3 z`wl2Jmil%LAshd3GtknqLt{I5w^$Sn?>>;da)^(_xWAtwn6&V5Q%kVjQ%%`Fi#J$K z#Aypnd8pV=4vl|R`uhpohNe!8j=Yfbv_o#7w1D+(+2wh8I6?=o^WilMAGpNeNLaz1 zRG8xzVYuS#XskB)^FzweeO!G^e@w2Jb zTy@mKGHg8#f{XXGued(dh$`(d)z05f#@hP|Y}Vaf0K6kCeW^2R#k$M<{i1 zQT1-i_@fUsa2zv`QNRwODBp2z|9yfDXU!wHU$Di=Mo+VXfzfw>{r|DP0o=_0iVDjs ZqAn#(9&-{s@O5BdBt>OKDuwm^{s$pBT$2C* literal 0 HcmV?d00001 diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh index bfb7f9ffe5..a34624536c 100755 --- a/example/ck_tile/02_layernorm2d/script/perf_test.sh +++ b/example/ck_tile/02_layernorm2d/script/perf_test.sh @@ -2,37 +2,37 @@ # run from top of ck folder EXE=build/bin/tile_example_layernorm2d_fwd -$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 
-prec=bf16 -repeat=1000 -$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=288 
-e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 -$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index dcd40fda40..d56406b6f2 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -2,30 +2,34 @@ # call from top of CK folder EXE=./build/bin/tile_example_layernorm2d_fwd +for fquant in "" "-fquant=1 -prec_o=int8"; do for 
pr_i in "fp16" "bf16" ; do -$EXE -prec=$pr_i -m=99 -n=13 -$EXE -prec=$pr_i -m=17 -n=16 -$EXE -prec=$pr_i -m=1 -n=100 -$EXE -prec=$pr_i -m=4 -n=128 -$EXE -prec=$pr_i -m=80 -n=127 -$EXE -prec=$pr_i -m=22 -n=255 -stride=256 -$EXE -prec=$pr_i -m=7 -n=599 -$EXE -prec=$pr_i -m=19 -n=512 -$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 -$EXE -prec=$pr_i -m=11 -n=510 -$EXE -prec=$pr_i -m=171 -n=676 -stride=818 -$EXE -prec=$pr_i -m=91 -n=636 -$EXE -prec=$pr_i -m=12 -n=768 -stride=800 -$EXE -prec=$pr_i -m=100 -n=766 -stride=812 -$EXE -prec=$pr_i -m=31 -n=1024 -$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 -$EXE -prec=$pr_i -m=8 -n=1501 -$EXE -prec=$pr_i -m=3 -n=1826 -$EXE -prec=$pr_i -m=5 -n=2040 -$EXE -prec=$pr_i -m=7 -n=2734 -$EXE -prec=$pr_i -m=1 -n=3182 -$EXE -prec=$pr_i -m=9 -n=4096 -$EXE -prec=$pr_i -m=3 -n=8192 -$EXE -prec=$pr_i -m=1 -n=10547 -$EXE -prec=$pr_i -m=3 -n=17134 +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE 
-prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done done diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2c423831e1..3b198502d0 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -25,6 +25,7 @@ #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/int8.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" diff --git a/include/ck_tile/core/numeric/int8.hpp b/include/ck_tile/core/numeric/int8.hpp new file mode 100644 index 0000000000..9ca3333c39 --- /dev/null +++ b/include/ck_tile/core/numeric/int8.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/random.hpp" +#include +#include + +#pragma once + +namespace ck_tile { + +// use int8_t directly for int8 arithemetic +// here one can use ck_tile::int8_t to access original int8_t +using int8_t = int8_t; + +// limits +template +struct numeric; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr int8_t epsilon() + { + return 1; // not used + } + + CK_TILE_HOST_DEVICE static constexpr int8_t round_error() + { + return 1; // not used + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr int8_t infinity() + { + return 1; // not used + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN() + { + return 1; // not used + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN() + { + return 1; // not used + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min() + { + return 1; // not used + } + + CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; } +}; + +#if 0 +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t 
head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; +#endif + +CK_TILE_HOST_DEVICE +constexpr float int8_to_float(const int8_t& x) { return static_cast(x); } + +CK_TILE_HOST_DEVICE +constexpr int8_t float_to_int8(const float& x) { return static_cast(x); } + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index cb18cde70d..4011e08ce4 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -10,6 +10,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/int8.hpp" namespace ck_tile { @@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float) CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float) CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) +CK_TILE_TYPE_CONVERT(float, float, int8_t, int8) +CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) + #undef CK_TILE_TYPE_CONVERT #endif diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 9707f2990a..de99be1965 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, return null_tile_window>{window_lengths}; } +template +CK_TILE_DEVICE constexpr auto make_tile_window(const null_tile_window& t, + const StaticTileDistribution&) +{ + return t; +} + template CK_TILE_DEVICE void move_tile_window(null_tile_window&, diff --git a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp 
b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp index 837f52c399..62cd26b6ab 100644 --- a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp @@ -8,20 +8,44 @@ namespace ck_tile { +// Note: for simplicity, each functor only care about single M +struct reference_layernorm2d_default_epilogue +{ + template + void operator()(int m, HostTensor& o, const HostTensor& acc) + { + const int N = acc.mDesc.get_lengths()[1]; + for(int n = 0; n < N; ++n) + { + o(m, n) = ck_tile::type_convert(acc(m, n)); + } + } + + template + auto operator()(int m, const HostTensor& acc) + { + HostTensor o(acc.get_lengths(), acc.get_strides()); + operator()(m, o, acc); + return o; + } +}; + template + typename InvStdDataType, + typename Epilogue = reference_layernorm2d_default_epilogue> void reference_layernorm2d_fwd(const HostTensor& x_m_n, const HostTensor& gamma_n, const HostTensor& beta_n, HostTensor& y_m_n, HostTensor& mean_m, HostTensor& invStd_m, - ComputeDataType epsilon) + ComputeDataType epsilon, + Epilogue epilogue_functor = {}) { auto layernorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; @@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor& x_m_n, if constexpr(!std::is_same_v) invStd_m(m) = ck_tile::type_convert(divisor); + HostTensor acc(x_m_n.get_lengths(), x_m_n.get_strides()); for(int n = 0; n < N; ++n) { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); ComputeDataType beta = ck_tile::type_convert(beta_n(n)); - auto y = (x - mean) * divisor; - y = y * gamma + beta; + auto a_ = (x - mean) * divisor; + a_ = a_ * gamma + beta; - y_m_n(m, n) = ck_tile::type_convert(y); + acc(m, n) = a_; } + + epilogue_functor(m, y_m_n, acc); }; make_ParallelTensorFunctor(layernorm2d_fwd_func, diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 
eb06fea2dd..fb8d7221b8 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -9,4 +9,5 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 4363ea1f55..1510f18a30 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -3,4 +3,5 @@ #pragma once +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp similarity index 96% rename from include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp rename to include/ck_tile/ops/common/generic_2d_block_shape.hpp index e4b60331eb..64ad20c3be 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -1,11 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once -#include "ck_tile/core.hpp" - namespace ck_tile { + /* // clang-format off @@ -42,7 +41,7 @@ template typename Vector_, // contiguous pixels(vector size) along seq index_t BlockSize_ = warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> -struct Layernorm2dShape +struct Generic2dBlockShape { // block size static constexpr index_t Block_M = BlockTile_::at(number<0>{}); diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 62ba9dc0b3..cd1e43fb8c 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -4,4 +4,5 @@ #pragma once #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index a98f60b364..c24744bdbc 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -5,4 +5,6 @@ #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 5dc49c3b0e..7c5d5a6f31 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -9,23 +9,29 @@ namespace ck_tile { // this epilogue just store out a M*N matrix, row major -template +template struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + 
static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; }; template struct Default2DEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool UseRawStore = Problem::UseRawStore; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } @@ -36,7 +42,7 @@ struct Default2DEpilogue { // TODO: this is ugly - if constexpr(kPadM || kPadN) + if constexpr(UseRawStore && (kPadM || kPadN)) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); buffer_store_fence(); diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp new file mode 100644 index 0000000000..2e29604116 --- /dev/null +++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +namespace ck_tile { + +template +struct DynamicQuantEpilogueTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr bool UseMax3 = UseMax3_; +}; + +// this epilogue just stores out a M*N matrix, row major +template +struct DynamicQuantEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockShape = remove_cvref_t; // can consume generic 2d shape + using Traits = remove_cvref_t; +}; + +template +struct DynamicQuantEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kPadM = Problem::Traits::kPadM; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool UseRawStore = Problem::Traits::UseRawStore; + static constexpr bool UseMax3 = Problem::Traits::UseMax3; + + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); + return reduce_crosswarp_sync.GetSmemSize(); + } + + // TODO: this function assumes store out vector size is the same as OAccTile last dimension size + // how do we fix this ? 
+ template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + YScaleWindow& y_scale_window, + const OAccTile& o_acc_tile, + void* smem) + { + auto reduce = GetBlockReduce2d(); + auto reduce_sync = GetBlockReduce2dSync(); + auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); + + const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; + + auto row_absmax = [&]() { + constexpr auto y_size_per_row = + OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at( + number<1>{}); + // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}]; + if constexpr(UseMax3 && std::is_same_v && y_size_per_row % 2 == 0) + { + // fast max3 implementation + const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) { + float rtn; + asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" + : "=v"(rtn) + : "v"(acc_), "v"(v_0_), "v"(v_1_)); + return rtn; + }; + return reduce(o_acc_tile, type_convert(0), f_max3, sequence<1, 2>{}); + } + else + { + return reduce(o_acc_tile, type_convert(0), f_absmax); + } + }(); + reduce_sync(row_absmax, f_absmax); + reduce_crosswarp_sync(row_absmax, smem, f_absmax); + + // here y_scale is Acc type; need to convert to YScale type later + auto y_scale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + row_absmax); + + store_tile(y_scale_window, cast_tile(y_scale)); + + auto o_acc_scaled_tile = + make_static_distributed_tensor(o_acc_tile.get_tile_distribution()); + + sweep_tile(o_acc_tile, [&](auto idx) { + constexpr auto row_id = make_tuple(idx[number<0>{}]); + o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id); + }); + + // TODO: this is ugly + if constexpr(UseRawStore && (kPadM || kPadN)) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_scaled_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_scaled_tile)); + } + } +}; +} // namespace ck_tile diff --git 
a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 9389a5397f..e106264cef 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -43,4 +43,5 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c3e028528b..ac74782a3a 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -39,4 +39,5 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 57e83a7a51..2b02bcc5d2 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -6,4 +6,5 @@ #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 2a403b0f49..711c5d8595 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -4,9 +4,10 @@ #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" -#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp" #include 
"ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index cebe5131a7..9a2e06d05f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -5,19 +5,24 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" namespace ck_tile { // host side args struct Layernorm2dFwdHostArgs { - const void* p_x; - const void* p_gamma; - const void* p_beta; + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used + const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used + const void* p_gamma; // [1, n], gamma, prec same as input + const void* p_beta; // [1, n], beta, prec same as input - void* p_y; - void* p_mean; - void* p_invStd; + void* p_y; // [m, n], output, fp16/bf16 + void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used + void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used + void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used + void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used float epsilon; @@ -27,10 +32,11 @@ struct Layernorm2dFwdHostArgs }; // TODO: Extract some type to wrapper 
class -template +template struct Layernorm2dFwd { using Pipeline = remove_cvref_t; + using Epilogue = remove_cvref_t; using Problem = typename Pipeline::Problem; using XDataType = remove_cvref_t; @@ -40,18 +46,26 @@ struct Layernorm2dFwd using YDataType = remove_cvref_t; using MeanDataType = remove_cvref_t; using InvStdDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + + // for simplicity, shortcut input/output type is same as X + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMeanInvStd = Problem::kSaveMeanInvStd; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMeanInvStd = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; - static constexpr index_t Block_M = Problem::BlockShape::Block_M; - static constexpr index_t Block_N = Problem::BlockShape::Block_N; - static constexpr bool kPadM = false; // always no need to pad along M - static constexpr bool kPadN = Problem::kPadN; - static constexpr bool kTwoPass = Problem::kTwoPass; + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kTwoPass = Problem::Traits::kTwoPass; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; @@ 
-62,13 +76,18 @@ struct Layernorm2dFwd struct Kargs { - const void* p_x; - const void* p_gamma; - const void* p_beta; + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used + const void* p_x_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used + const void* p_gamma; // [1, n], gamma, prec same as input + const void* p_beta; // [1, n], beta, prec same as input - void* p_y; - void* p_mean; - void* p_invStd; + void* p_y; // [m, n], output, fp16/bf16 + void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used + void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used + + void* p_mean; // [m, 1], output mean, prec same as input, nullptr if not used + void* p_invStd; // [m, 1], output inv-stdvariance, prec same as input, nullptr if not used float epsilon; @@ -81,9 +100,13 @@ struct Layernorm2dFwd CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) { return Kargs{hargs.p_x, + hargs.p_x_residual, + hargs.p_x_scale, hargs.p_gamma, hargs.p_beta, hargs.p_y, + hargs.p_y_residual, + hargs.p_y_scale, hargs.p_mean, hargs.p_invStd, hargs.epsilon, @@ -106,6 +129,7 @@ struct Layernorm2dFwd template <> struct t2s { static constexpr const char * name = "bf16"; }; template <> struct t2s { static constexpr const char * name = "fp8"; }; template <> struct t2s { static constexpr const char * name = "bf8"; }; + template <> struct t2s { static constexpr const char * name = "int8"; }; // clang-format on // in byte @@ -113,24 +137,41 @@ struct Layernorm2dFwd CK_TILE_HOST static std::string GetName() { +#define _SS_ std::string +#define _TS_ std::to_string // clang-format off using S_ = typename Problem::BlockShape; auto surfix = [&] () { std::string n; + if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName::name; + if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + 
Layernorm2dFusedQuantEnumName::name; if (kPadN) n += "_pn"; if (kSaveMeanInvStd) n += "_mv"; - if (kTwoPass) n += "_2p"; + // if (kTwoPass) n += "_2p"; return n; }(); - #define _SS_ std::string - #define _TS_ std::to_string - return _SS_("layernorm2d_fwd_") + _SS_(t2s::name) + "_" + + auto prec_str = [&] () { + std::string base_str = _SS_(t2s::name); + if (!std::is_same_v) { + base_str += _SS_("_") + _SS_(t2s::name); + } + if (kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { + base_str += _SS_("_sx") + _SS_(t2s::name); + base_str += _SS_("_sy") + _SS_(t2s::name); + } + if (kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) { + base_str += _SS_("_sy") + _SS_(t2s::name); + } + return base_str; + }(); + + return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _SS_(Pipeline::name) + surfix; - #undef _SS_ - #undef _TS_ // clang-format on +#undef _SS_ +#undef _TS_ } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -153,6 +194,31 @@ struct Layernorm2dFwd tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); + const auto x_residual_window = [&]() { + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x_residual), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel + // will check the max count dynamically + const auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + const 
auto gamma_window = [&]() { const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_gamma), @@ -194,6 +260,28 @@ struct Layernorm2dFwd tmp2_, make_tuple(number{}, number{}), {iM, 0}); }(); + auto y_residual_window = [&]() { + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y_residual), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + auto mean_window = [&]() { if constexpr(kSaveMean) { @@ -232,17 +320,60 @@ struct Layernorm2dFwd return make_null_tile_window(make_tuple(number{})); }(); + auto x_scale_window = [&]() { + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + const auto win_ = [&]() { + const auto tmp_0_ = make_naive_tensor_view_packed( + static_cast(kargs.p_x_scale), + make_tuple(kargs.n), + number{}); + + return pad_tensor_view(tmp_0_, + make_tuple(number{}), + sequence{}); // x_scale no need pad + }(); + return make_tile_window(win_, make_tuple(number{}), {0}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + auto y_scale_window = [&]() { + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT || + kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT) + { + const auto win_ = [&]() { + const auto tmp_0_ = make_naive_tensor_view_packed( + static_cast(kargs.p_y_scale), + make_tuple(kargs.m), + number<1>{}); + + return pad_tensor_view( + tmp_0_, make_tuple(number{}), sequence{}); + }(); + return make_tile_window(win_, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + __shared__ char smem[GetSmemSize()]; Pipeline{}(x_window, + 
x_residual_window, gamma_window, beta_window, y_window, + y_residual_window, mean_window, inv_std_window, + x_scale_window, + y_scale_window, static_cast(kargs.epsilon), kargs.n, - smem); + smem, + Epilogue{}); } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index c767a472a9..16a7c3b86d 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include #include @@ -24,20 +25,25 @@ struct Layernorm2dFwdPipelineOnePass using MeanDataType = ck_tile::remove_cvref_t; using InvStdDataType = ck_tile::remove_cvref_t; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr_op"; // block per row + return "bpr"; // block per row else - return "wpr_op"; // warp per row + return 
"wpr"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass } template + typename InvStdWindow, + typename XScaleWindow, + typename YScaleWindow, + typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, const GammaWindow& gamma_window_, const BetaWindow& beta_window_, - YWindow& y_window, + YWindow& y_window_, + const YResidualWindow& y_residual_window_, MeanWindow& mean_window, InvStdWindow& inv_std_window, + const XScaleWindow& x_scale_window_, + YScaleWindow& y_scale_window, ComputeDataType epsilon, ck_tile::index_t row_size, - void* smem) const + void* smem, + Epilogue) const { const auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); @@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); const auto beta_window = make_tile_window( beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + const auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); + const auto x_scale_window = make_tile_window( + x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + auto x_scale = load_tile(x_scale_window); - const auto x = load_tile(x_window); int cur_count = 0; int max_count = block_tile_welford_calculate_max_count(row_size); @@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute 
x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + store_tile(y_residual_window, x); + } + // compute welford each-thread->cross-lane->cross-warp auto [mean, var] = block_welford(x, cur_count, max_count); block_welford_sync(mean, var, cur_count); @@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass store_tile(inv_std_window, cast_tile(inv_std)); // layernorm computation - auto y = make_static_distributed_tensor(x.get_tile_distribution()); - sweep_tile(y, [&, mean_ = mean](auto idx) { + auto ln = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass const auto beta_ = type_convert(beta[j_idx]); const auto x_ = type_convert(x[idx]); - auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; - y(idx) = type_convert(y_); + ln(idx) = ln_; }); - store_tile(y_window, y); + + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + // smooth-quant pre-scale, then run rowwise-quant + sweep_tile(ln, [&](auto idx) { + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + const auto xs_ = type_convert(x_scale[j_idx]); + ln(idx) = ln(idx) * xs_; + }); + } + + if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || + kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) + { + Epilogue{}(y_window_, y_scale_window, ln, smem); + } + else + Epilogue{}(y_window_, ln); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp index 8e9f8e81e4..7ec830add1 100644 --- 
a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp @@ -14,10 +14,10 @@ template + typename Traits_> struct Layernorm2dFwdPipelineProblem { using XDataType = remove_cvref_t; @@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem using YDataType = remove_cvref_t; using MeanDataType = remove_cvref_t; using InvStdDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; - static constexpr bool kTwoPass = kTwoPass_; + using Traits = remove_cvref_t; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index e35d02e707..ec10efbc69 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass using MeanDataType = ck_tile::remove_cvref_t; using InvStdDataType = ck_tile::remove_cvref_t; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + static constexpr bool kHasGamma = !std::is_same_v; static constexpr bool kHasBeta = !std::is_same_v; - static constexpr bool kSaveMean = Problem::kSaveMeanInvStd; - static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd; + static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd; + static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool 
kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM - static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; + static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr_tp"; // block per row + return "bpr_2p"; // block per row else - return "wpr_tp"; // warp per row + return "wpr_2p"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass } template + typename InvStdWindow, + typename XScaleWindow, + typename YScaleWindow, + typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XResidualWindow& x_residual_window_, const GammaWindow& gamma_window_, const BetaWindow& beta_window_, YWindow& y_window, + const YResidualWindow& y_residual_window_, MeanWindow& mean_window, InvStdWindow& inv_std_window, + const XScaleWindow& /*x_scale_window*/, + YScaleWindow& /*y_scale_window*/, ComputeDataType epsilon, ck_tile::index_t row_size, - void* smem) const + void* smem, + Epilogue) const { auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); @@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution()); auto beta_window = make_tile_window( beta_window_, Policy::template MakeGammaBetaBlockTileDistribution()); + auto x_residual_window = make_tile_window( + x_residual_window_, Policy::template MakeXBlockTileDistribution()); + auto y_residual_window = make_tile_window( + y_residual_window_, Policy::template MakeXBlockTileDistribution()); // Problem::BlockShape static constexpr index_t Block_N = Problem::BlockShape::Block_N; @@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) 
{ - const auto x = load_tile(x_window); - block_welford(x, mean, var, cur_count, max_count); + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + move_tile_window(x_window, {0, Block_N}); + move_tile_window(x_residual_window, {0, Block_N}); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + store_tile(y_residual_window, x); + move_tile_window(y_residual_window, {0, Block_N}); + } + } + block_welford(x, mean, var, cur_count, max_count); } block_welford_sync(mean, var, cur_count); @@ -119,6 +155,7 @@ struct Layernorm2dFwdPipelineTwoPass row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); @@ -126,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass // layernorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto x = load_tile(x_window); + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || + kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + x(idx) = type_convert(x_resi(idx)) + + type_convert(x(idx)); + }); + } // load gamma/beta (TODO: support no gamma/beta?) 
const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); - auto y = make_static_distributed_tensor(x.get_tile_distribution()); + auto ln = make_static_distributed_tensor(x.get_tile_distribution()); - sweep_tile(y, [&, mean_ = mean](auto idx) { + sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -141,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass const auto beta_ = type_convert(beta[j_idx]); const auto x_ = type_convert(x[idx]); - auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; - y(idx) = type_convert(y_); + ln(idx) = ln_; }); - store_tile(y_window, y); + static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); + Epilogue{}(y_window, ln); move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp new file mode 100644 index 0000000000..fb327f74a3 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +enum class Layernorm2dFusedAddEnum +{ + NO_ADD = 0, + // fused add before layernorm and store result to global + PRE_ADD_STORE = 1, + // fused add before layernorm, but not store result + PRE_ADD = 2, +}; + +// clang-format off +template struct Layernorm2dFusedAddEnumName; +template<> struct Layernorm2dFusedAddEnumName { static constexpr const char * name = "no"; }; +template<> struct Layernorm2dFusedAddEnumName { static constexpr const char * name = "pras"; }; +template<> struct Layernorm2dFusedAddEnumName { static constexpr const char * name = "pra"; }; +// clang-format on + +enum class Layernorm2dFusedQuantEnum +{ + NO_SWEEP = 0, + SMOOTH_DYNAMIC_QUANT = 1, // smooth oulier + rowwise quant, need input x-scale and store y_scale + DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale +}; + +// clang-format off +template struct Layernorm2dFusedQuantEnumName; +template<> struct Layernorm2dFusedQuantEnumName { static constexpr const char * name = "no"; }; +template<> struct Layernorm2dFusedQuantEnumName { static constexpr const char * name = "dqt"; }; +template<> struct Layernorm2dFusedQuantEnumName { static constexpr const char * name = "smdqt"; }; +// clang-format on + +template +struct Layernorm2dFwdTraits +{ + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; + static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index ee8c693727..990e9ecc03 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -5,4 +5,5 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" +#include 
"ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index fe2d24044e..aa617ee2b4 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -7,4 +7,5 @@ #include "ck_tile/ops/reduce/block/block_reduce2d.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index fa3007d1e4..c93329bfbe 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -301,7 +301,10 @@ struct BlockReduce2D .get_static_tile_distribution_encoding(), ReduceDim{})); - return make_static_distributed_tensor(acc_dstr); + auto dst_ = make_static_distributed_tensor(acc_dstr); + // init acc_tensor + tile_elementwise_inout([&](auto& x_) { x_ = type_convert(reduce_init); }, dst_); + return dst_; } // return number of pixels each lane need to reduce diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index beb8c718e3..3c68147112 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -17,14 +17,24 @@ struct BlockReduce2d CK_TILE_DEVICE constexpr BlockReduce2d() {} - template + template > CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, YDistributedTensor_& y_tensor, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + ReducePacksPerXDim = {}) { + sweep_tile( + [&](auto... 
idx_) { + constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]); + y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...); + }, + ReducePacksPerXDim{}); +#if 0 constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; - constexpr auto spans = XDistributedTensor_::get_distributed_spans(); // FIXME: hard coded to reduce 2nd axis @@ -42,6 +52,7 @@ struct BlockReduce2d y_tensor(y_dstr_idx) = y; }); +#endif } template @@ -63,14 +74,17 @@ struct BlockReduce2d return tensor; } - template + template > CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor, const ComputeDataType& reduce_init, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + ReducePacksPerXDim = {}) { auto y_tensor = MakeYBlockTile(); set_tile(y_tensor, reduce_init); - (*this)(x_tensor, y_tensor, reduce_func); + (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{}); return y_tensor; } diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 98c60f1b51..f0a6cf9603 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -9,4 +9,5 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 584ca70689..4df34e1e0d 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -5,4 +5,5 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk.hpp 
b/include/ck_tile/ops/topk.hpp index b1143e4a06..fcae3e02dc 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -5,4 +5,5 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 809473d53b..cc7dbffee4 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -7,4 +7,5 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp index ebf9406837..a4c479dd95 100644 --- a/include/ck_tile/ops/welford.hpp +++ b/include/ck_tile/ops/welford.hpp @@ -6,4 +6,5 @@ #include "ck_tile/ops/welford/block/block_welford.hpp" #include "ck_tile/ops/welford/block/block_welford_problem.hpp" #include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"