From d939c3b4fcbc6d220874ccf3b81a2a6c469cb680 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 19 May 2026 21:45:23 -0400 Subject: [PATCH] sparse_attn: split-launch dispatch + 3-mode PV-skip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Per-head pv_threshold via head_remap LUT (CLI: -pv_threshold_per_head); sentinel 1e30 routes to kEnablePVSkip=false bucket - kEnablePVSkip bool → PVSkipMode enum {kNone, kPerWarp, kPerBlock}; new kPerBlock matches upstream sm80 (LDS vote, V loads unconditional). CLI: -pv_mode={none,warp,block}, default warp - README: PV-skip modes section + MI300X 3-curve sparsity chart Co-Authored-By: Claude Opus 4 --- example/ck_tile/50_sparse_attn/README.md | 17 +- .../codegen/ops/fmha_fwd_sparge.py | 121 +++++++---- .../docs/pv_skip_mode_comparison.png | Bin 0 -> 113868 bytes .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 46 +++- .../50_sparse_attn/sparge_blockmap_inst.cpp | 136 +++++++++++- .../ck_tile/50_sparse_attn/test_sparge.cpp | 101 +++++++-- .../kernel/fmha_fwd_sparge_kernel.hpp | 62 +++++- ...ck_fmha_pipeline_qr_ks_vs_async_sparge.hpp | 197 ++++++++++++++++-- 8 files changed, 585 insertions(+), 95 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png diff --git a/example/ck_tile/50_sparse_attn/README.md b/example/ck_tile/50_sparse_attn/README.md index c7191c8e82..9fdad906de 100644 --- a/example/ck_tile/50_sparse_attn/README.md +++ b/example/ck_tile/50_sparse_attn/README.md @@ -14,10 +14,23 @@ Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/ - **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53)) - **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338)) - **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355)) -- **pv_threshold per-Q-tile skip in attn kernel** — pure perf, ~5–15% on the dominant attention slice ([spas_sage_attn/core.py:L265](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L265)) - **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345)) - **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371)) +## PV-skip modes + +`pv_threshold` per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime via `-pv_mode={none|warp|block}`: + +- **`none`** — skip disabled; baseline matching the no-PV-skip codegen instance. +- **`warp`** (per-wavefront) — each wavefront votes locally via `__shfl_xor` butterfly AND; SGPR-resident flag. CK-tile-specific variant, not in upstream. +- **`block`** (per-block) — block-wide consensus vote via LDS broadcast; aligned with upstream sm80 ([`qk_int_sv_f16_cuda_sm80.cuh:L334`](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cuh#L334)). V loads stay unconditional in all modes — the guard wraps the PV MMA only, matching upstream and paper Algorithm 1. + +![PV-skip mode comparison](docs/pv_skip_mode_comparison.png) + +*MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. All three modes dispatch to the `kM0=64 padK=0` tile bucket at this shape.* + +On the canonical recipe shape, `none > warp > block` at every measured sparsity, with no crossover. The per-block guard adds +33..+35 VGPR (6..9 spills) on this tile configuration, depressing occupancy. `warp` is +0..+4 VGPR. The default is `-pv_mode=warp` (preserves R25 A1 behaviour); switch to `none` for the no-skip baseline or `block` to exercise the upstream-aligned variant. A shape sweep is needed before recommending `block` as default — the `kM0=128` path has Δ ≈ 0 VGPR for per-block and is a candidate. + ## Performance At b=2 h=32 s=16384 fp16, sparge (vsa backend) reaches **1.78× FMHA throughput at topk=0.4** and **5.04× at topk=0.1**, and stays above 1.0× across the full topk range. @@ -37,6 +50,8 @@ ninja tile_example_sparge ./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001 ``` +Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire. + Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. ## References diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py index 9489d3758f..e5182c3dc8 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py @@ -114,48 +114,55 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_trload}, fmha_trait_{F_idx}>; -// R25 V0: instantiate the Sparge pipeline with kEnablePVSkip = true (existing path) -// AND kEnablePVSkip = false (PV-skip AST removed at compile time, source-equivalent -// to the frozen VSA reference). Both kernels live in the same TU; the host dispatch -// in fmha_sparge_fwd_api.cpp picks one based on fmha_sparge_fwd_args::pv_skip_compile. -using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< - fmha_pipeline_problem_{F_idx}, - ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, - true>; +// R30: emit 3 pipeline / kernel instances per traits combo — kNone (PV-skip +// AST removed; source-equivalent to VSA), kPerWave (R25 A1 shipped path), +// kPerBlock (R30 added: block-wide AND vote gates gemm_1). The host dispatch +// in fmha_sparge_fwd_api.cpp picks one based on +// fmha_sparge_fwd_args::pv_mode_compile (0/1/2). +// R26 split-launch: fmha_fwd_create_kargs_and_grids(a) forwards the new +// fmha_sparge_fwd_args fields (pv_threshold_per_head_ptr, head_remap_ptr, +// nhead_in_launch) to MakeKargs. When head_remap_ptr is non-null the wrapper +// also shrinks grids.y to nhead_in_launch so each bucket fires its own kernel. +// Suffixes: +// _pvsf = PV-Skip OFF (kNone) +// _pvst = PV-Skip per-WAVE (kPerWave; preserved R25 A1 binary name) +// _pvsb = PV-Skip per-BLOCK (kPerBlock; R30 new) using fmha_pipeline_{F_idx}_pvsf = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< fmha_pipeline_problem_{F_idx}, ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, - false>; + ck_tile::PVSkipMode::kNone>; +using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerWave>; +using fmha_pipeline_{F_idx}_pvsb = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerBlock>; using fmha_epilogue_{F_idx} = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx}_pvst = - ck_tile::FmhaFwdSpargeKernel; using fmha_kernel_{F_idx}_pvsf = - ck_tile::FmhaFwdSpargeKernel; + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvst = + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvsb = + ck_tile::FmhaFwdSpargeKernel; using trait_{F_idx} = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include +// R30: 3 specializations per traits combo — int kPVMode values: +// 0 = kNone (pvsf binary) +// 1 = kPerWave (pvst binary; R25 A1 path) +// 2 = kPerBlock (pvsb binary; R30 new) template<> -float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}_pvst; - if(s.log_level_ > 0) - std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}_pvsf; if(s.log_level_ > 0) @@ -167,7 +174,42 @@ float fmha_sparge_fwd_(const ck_tile::stream_config& s, fm }} template<> -void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsb; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvsb" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}_pvst; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); @@ -178,9 +220,9 @@ void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& }} template<> -void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}_pvsf; + using k_ = fmha_kernel_{F_idx}_pvsb; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; @@ -261,10 +303,13 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - if(a.pv_skip_compile) - return fmha_sparge_fwd_(s, a); - else - return fmha_sparge_fwd_(s, a); + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. + switch(a.pv_mode_compile) {{ + case 0: return fmha_sparge_fwd_(s, a); + case 1: return fmha_sparge_fwd_(s, a); + case 2: return fmha_sparge_fwd_(s, a); + default: return fmha_sparge_fwd_(s, a); // legacy default = per-wave + }} }} """ @@ -302,11 +347,13 @@ FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; - if(a.pv_skip_compile) - fmha_sparge_fwd_oneshot_(s, a); - else - fmha_sparge_fwd_oneshot_(s, a); - return; + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. + switch(a.pv_mode_compile) {{ + case 0: fmha_sparge_fwd_oneshot_(s, a); return; + case 1: fmha_sparge_fwd_oneshot_(s, a); return; + case 2: fmha_sparge_fwd_oneshot_(s, a); return; + default: fmha_sparge_fwd_oneshot_(s, a); return; + }} }} """ diff --git a/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png b/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..b35c20a679ad9dd16548170b39c8836fc8caefea GIT binary patch literal 113868 zcmdSBc{rA9+c$iXR8)pcDPuAwDhXwlgfdT=lR~LXMP?ZzLxxa^M5aQKAyZ_CLP}&- zMCOnoR!|e|+2a{_(EuzSml{bY16poX4^6ziB^1G*6${M#V}+B9XS8 zR8i6NIpI_yuU$waPUtek!7Q1D;`hR|vfAIDH@n4!Bp=x2lsl-j^ z=jV4z_cRZ^JP$pWxcE!mF+H=+?ruf=iktgUPIk3w!-%8$rfsYgQ7ZmUBh9fcdhwLE z+=bU;mX?;}@KX{J#rUaC_hpm56RI`$j$E1f)nQ)K*T-I3T1p$DH1qSPz_=Zw+O=!q zS$f%&LPA1io+~eNvTI}xVfk}*M|rC>ue5&ttQLZv- zMMjz*C`PbLTUc9f%+#gYv}scq>tQO5q+gCg^A}Gud6V zG#JD&Fg3LuUq1Td$7?L=pFe+=u)b=!vuUH1>wjt*qV^094N*u-ORJ=5hc8lY z*svj3i5p+aY#1JV73=M>GFddUu%KLfi!miTn;KUwcJcF8wY&R|pE+~JcU(nX{SCfs z@cVa4D=Vwfv9apLMk+5aug=1Y2KkF``};Yv^`BOUGK**yT{d-fb9>$1Uf4Mvt)`|{ z!{NQ^J6*A|9Y38{P_XZFMBZHy^Sag)#le9A5}p{os<^cD?3-Y^rRmCbhS{;vQH$^I z5?p_N-ROlyT-XDt0=Ty%NM^7@~ZY3(mT@;52zNVsqPs7A@7N26c7I3(@0TKxCCeEC90 zqN1S@8PCe2Z_CmXa=dmeotk-njEkP?ZF~_~GS)yqxcj+v=v7m4ImMvu!dFe--#cQ^ z;62qN88AFJIEY<}>s2JF@X%{v`7*uMeh+>9s_^^wZ-s<>tNN_og}`mQRMgZ$Ns$2o z8%s(`qVn?(ka6O^e=qiX>%~RCJy%^wNBYv;t+V#~*GU6|yCltPVpu$5)+Tyd?phsiiiO--|Oo@_sS zsZ&$sVsEL-k108U&Ps3DlRWg~a%c17?LOz6eD>^F^|&4n{fEX_A-%Z6xUMfriMz+~ zP)|3QrzmWkota6-j@;aBQsGh3=qN5Ot{|zSqhr^aOxB*Eb+e$ba0>~a*Hz)UvLY68 zhGXo->$WsCy_eY!9+>(LhqE0`kL>U3OX}(EecRZms*-K=f>-*+yo&WQHoR7QmR^9B z#g!`>+jfcT7H^-Q>b-aFpi$Jq$nRg{6h8Cw^ZdNL=_=OLPQ#6ivGyC~zS+M|k~SWc zuJ7%AVAr0t*JQ59zB^r)_vq1kcAqljaZ6{dtkj$zvIuQPgAqD0#BoQ@leYK81)7~K zEVQJtwzSYlv+^4*3gxaDcL-=tWMlwnyD zD0sJNgAVTboQ;i&va)i_ojZ>XPCqz$-T$JY1Gx_0>HEp1udaP)x+}uirlYNGvoP^R zp+wq!fj_qV+GOI)MyIh)nK20o0x^>vxg2Dis9&lY8U})|&cA0JY;x&yk@Z-L9Q*uK zjg)bvzd~@)4TqS*=drc`o|T$qf;w-i{?jvCe1jf&2?#aRI|}#bS#D_0)Cu~WYp#)N zR!-k$eQ>9JUs=MNvktzK29Xg0qU}$OkBs%cIwEYC=X3hLgx0`_@9I=vxf(Ao@8h(z z<1a3>q(6Sl6EGaXA)9VhA0DgyUhRaoHhn@u0v%~~ZZ5UBSi;ii>{)Kl<>9*&KF-e0 zqf@=FREO`0*(z+H+?qVxbob+z7Z)&NBoigdB&JbHN~+Dk-19i|ev{<;_jhxxp>b4?>z+6fbh)q0D2z#o^xJVu#dGyn^7a#G{y+*Ej${=cng3dc29sO?b?qg4>yx+ zZEZh{cjShQ_TJMNZj2Q;denHb^u>!80p_RM*JjP=W#U8*4Kr>^?|h={bxO0NNsc|9 zs54bna^!NTQ+f%D-B&6ZpIO#!QD;o|Tv@0VnO^=i`ayB@LA3de8&Z-I611capo*4N``~w>qkhr+i}@{%Z|;Y!&mxDY>oc-TsPaYWsAb} z;+7q}3}^Zu31+uY9=$&GiC;7}KAxXqPA_WH09k8$wgF?@-o1Ob_&&kiW$x?GHV`3~ z!y)I_Ub}ho<`$BhyE}8HiNl-PTHD?6rA{N;@yv)j(ARhtM8k0>E{@5_&@f{)-#@>t z%1Qme!cJT42$CFc^dL4L$&ruY^XJcXsFe;w4RnRpG;C7-Xn67`PHft+Y5OPFdpJ6^ zIYuuuE4@5fSXrr@go|IjvTA#LQbAhG^e!4YuvAY^&&=wwlVw7w!=k*rJmAmfty`6_ z_~)DNZ@F;c0-AWr+1gu6UqeN19lmnYgPnD`Z1d4;6LQhKrz|dA+Dc7Ly|l9O?ekM( zg+FK+DX3j}&!77xC9&g^^Kb#=MGi|dmd-!gHtOr^(}vdOy?hy1St&~#L7U#Lm-f^s znJhbZzOJd+03`O=rV@uT6ix2++qYYYUb1Z66vwq^k7BAz20c4_xQyrTs?VR9v@;o) zn3!xj^SIJa-AkX{t#b0@YYo{oitE>}v+UZXXl~AN_~Pd`Z$fs^F*84L+~vk+?6tBW z%&U>K^~8x21TcC1n(WZ|ceD$>!FgV5H;PM2UbnW=G79R-P4|`m9C*DUcD(uMfr;C@ zj?mK5k}h7n2zc=+$7qMs_jg)HTUxNyM4i5?6CJ;n) z-T3swLqlsav{DDq?L(9viRRnMM}JO|a3DtkQ^Z*z(8i7(J9M-4x8Zkm+)p37F~8Z| zd#zBEFEk`%3l67$Z76egZ+)4I>(UHY;C8lK=rwZOQB{CcSbM?GN?E8)o40M-NOGR+ zDr{Z7xrv676gyYdaPHvhY?BB8YMl?+hFd1gr;3YCKW7NmP8F0BUA(61H3r3Jn2QS@nJ#TjAX|O(UR?KQly3lp<~St zj@I@yqgEX7T6G2fqFz~QH!vSW#jG_O3A&0Vb{}gK7Z=BJ^xCD6XDj$F7V%3gSq8cL z-`zcT@nR^z=Gm#964l)3J28ogRGf=T1{s?2wk0oLitpaNn?y1)GTQs{J$4`&>Bi#J z(}NlhWjskFpzGJECn6Z)h+|}0l$j}zC~i+h3chtqP=Bp&DzyEmY@0#$Y4sLx={|BfbB~`J z58Jl+lQe;ZGSP;sdv;J#QVPCmXIARVu5NA)PuF;On$+<5bKF=*u7-AI%*?uz+A09o z{K|A?v=TQL@Ry^L(~i>np-S9n5_cPaMeI6qy7S2`Mxi@FK|$3fmN*xhMwe>>w^fVW zM0XsTn%ciAc+6{ciQJp9x5D#PEANdPH*}wzW^;6QX86u$)Ao2vIS7)v+7Yt4lZ6Oq%%AuE(G(Jkf76@EEI02FAt&bGqy5l!tT2XRocG^{AMeKg0$m;2R5- z(3OWnL+8?-;~T4gW&lN%yn3aBBNpI@)e2~Q8C{uFxwc^ZJTH&a-rnBP#U%|GO6Q)q zeau+gdi>?iJo(N*>{Z5m6ZHNN1_b(rBiNW-7LM)r>7*$tjsm# zrDvRo`RVlZbcGdSAKOp=_>n1^b9@IL887Laz5P>hjXrjrg9%}PPRiU-DnHK9GBGKW zGH}S0=BE<1w58J>9UY05Md5S4CF!Z7EU6VoS2r#qIG8rDXPB*S&%=(!*RSO#U*7cd z^S_uXHJh}BKExl4koR3-J18Xg#hx)OvxkbyWM0 zB9*HO!JFLzCte9!Bx?%7U?X@MQ#;FDxu<6~8{!Bk;YCaSF$~ zx{z<3R{v2?lJo4vUYu)c(}%p$>Flyz`|hi4qG8uf7y+MENCHB%E0a}HvhOMo(9dOm zw77CfGf`~vfL=VGM$+cV-(aMI!gpIl3IPPutPdur7qd6tJED;~$#%>&vaGBuL-{r< zg-_V++W{q}fq~E5FBcw~$cPPe=1wPLu=D<~u%?P+Sh~dmq9XrNYoQY*yhR;Nh?i=yv#7gvT8< z0|SG!BOeM1#JbAei`sa`vdDEleE4t$<$xpv+lGo}MREc72$;OGIL#ba?)uY@Rs52Q+*Fw`b`h1+sdLuWTULw;?eqdf zQk14k1JP{=_!cny<;xcW=nd9b+t_3%ln~5rE-j4hsQ##zaP(W^**t#w^b|f^P}Ux3be3-BBj>Uq`>gD2H3$%T@f5#i25YJ-7Dqqa%z65hu_COe z+}&V_b8)h$UvM&LRq@rcXZu|yI<>f5U4GBoa@k1Oed14h&c($gSUCf3K*eFF=P4gm znS?#_wxh$q((Lx9Pr7NP!_w)v#dDW0pC;LWL8_9aWMuHCr8tO~SMY#BbGh!|;1JpO z%K68gnb#G-z{rNK|2!BV$9RoWjr!Icroeqby(q`{Qz#Rehyfs4F(Qs^wikEmZfA>&D?=~R>VJ% zRr-b)8K;_B6M4{!Y;Nz{V`F35-7oDK%~HYpnF1}E?@Q`rP5`akDvxj*QPV#0u`0%6 z=|?c#9+a%})spdxgo;t<^daLT#07FWT$#ZQ=`!pVoW zPRod*8^_Jf?aw(o@cQQAE{;>=Iwwz_RJb?1{^8VVnF~p!a|#cLLp41uq@UATy!*q7 zr+9E^sL;ZJFHo&gT!b>KK2--$2Xs8j;<$Tls^=sm4`Q1FH>cyzY1-KdnpS!}_`NtS z;-Rni>zlUx(6sDW;B4&B1dLs`${lw-?&#^s zZ|A0$Q33-z!R6Ty$we`=V5c<<@r`~gX1(Ug>az=Kar?B476xmA&RJN~GvQf_g z_@)@dq?3gZ2*~AKcech9N@jI#&+>?3)Wf6NnYzPw?cD}Qq{+T=W7}m^bUgQ5b938- zV>8-N0MhzyD3!eCTMlZ$!LcpKRVvw0N}jwh*)1fDrXwsYOubr_B`CUgopjO8F3Z8R zV*QW1rF#}HP0P`rV}N&hpHm>732vyLH8*?ZGM?_>8LSeY>_mIqJFRJ=hwql)uLl#W4m0yU^kC9x-ZHvGc)VW z>A!uOdj!YNZjN@if_Jz>hg5*BLq1m-QFfh9w|&oEadG{ngZpTh8iy-Cp4DPM{^O^^ zVD-k7?Z-?0G`%%X1;CDpj7%;tk;#9NbNEHKedK^2ZdaMxAxLC_cn~RA9%yw+*x6Bc?)U+NTL4*8 zE!Uns0K~9~)Co1L?e~x^_`XkgczE5f$QGknG%RjO$&!W@hk+`xg3GRQlr)CJzZfB&><7{sA^mC0?y@D@H0otH$~ z!FO7>q&WE+>{*51PzcCK1U+bUg#2re@CT6UzO*|Niob&KkO7*9tE(#~ z#9JI(-2}c$APR12X*M94*U9oX#RT`rQlkFd1i))^1OnX3CZ)@D#S1Fw2}nG#ZNv3% z-o4w3qpN*p3gRtkW3#)s^Qb0BqhQ+NRPQd(%!bBBZApV}e6XnZ9}nnJyPg^sF_}FC z``i?Hv?Wvu;*NhvNJ!UkmTrzA9UvmaQ-V|G7K;u@_{WE=~E*HgYTIKN7fR{Rw zG%9%VB!-@sS!Eu}^d#M8y42ISo)6F?-_+Gn&3R1pl7FSM}x2L%ZYYpAJhh;)CdcS10hYQ1TsUcl~4*|6?OT*NFiUoog) zRjYeJGZo`^@ri)GIFEnc<}dSLifdnV07nz z_C&DMw0XbgmE`2)v(r`%-@~*{{OP+N8yoO8n$PO@{EzBispwGECc%J2ufc{7o8_*q zt#_Dv2dMLY%uqjBgzCx)2`(x!k^~e=@IL5^tsg%sUb?go!n(=Vxs`@~?}*)!epf2j zn9vuipQHC$qAym@Rh#8Hf_c7%rp)zD(#FC<^RX1;;tR6Kquv${5hpb@HKWFicE5f5 zc4l$0zPsqk;KW22epbuA50K*<3Ju8z8fGfwV9`T|6me8vUg;0VKDv8GT5ax-jhCD3 zm64$#i}gQiHXYB{u&(IxbPNpspy6bsxP3-`9l7S&CJ`WmK0V8({TITu>cnc{CpD&= zh+rRtKrZ3Fuy4OvS@6WfB{ADJ3LyJ>P-M*#N9(kMvnOwg;6H)fQ~fY_X7U>Nh;i1& z4pkbqLe|sD%p_N zV2s9dq5g-f_a8jaqhuPcH!wEd1#MoTm1kYi%WUXjp56n{(+#9!($XnuXxxSKQZaMk%nW-LI#OkMZHT+rV3lzy7>`x8z$xoj?9WQRr3f6oZhnK%h zhHL=WV^@?p<;tX_rbZ14NlBN2?U7+zw4YZ!-lzA~_3Q^>B|KeyZ{MD@E{%<4BykD| zoc!iV>f2~NOKC60RjL2>=GbiI@$$e-D;t~OLl>F{-X$CgMvcHl=G%PWe)Hx{G613K z5pnFH4UtC^UP-y}#kK5W*{5oxM|NB?K|K&PV~?=#Hk?FG?yc;{3k zIOSqq8Xw=}S=RXPdh_ny5gH(08|dWk-oGcRVs&pXo6Gp;jgd!Fu3O^p1OdF0H#}%9 z1$aN6)w1FtzSNnPmX-oiSbXD`AYlVR;L6z61th;6{9YNy6%-^0#cu$Fz!l&K?%L0p zK}BL;d9VIFk@=~xK;~xFoyEgou+8f zMn#@0MbMbfSE^r7`5mwOD0zcEDa%F|{nYv|rl3s)#A zDvCwSh8D`?3`9-K*`c=tF(xk6ZSH&Knzr-$AJ3top{j<4&4qUDH>Sm-yyf|Aj2tSC zWn642lYT|jq`2-~6kq1TpHYFTUft_Hv^@9O)YJ&oVvsqn3{9Uz+PQP5LX}dxVxRl* zkLGT}*DZD0qvhC^y$p?wv)}9B90imuGk)$%+9=VkV}#zn zyt|2aE@?b5^5Mvs*9W~P*mt4_4!p*XIlv5BookVW29h!pkoOq1+LF*0BWRX;r-t6W zdq*0o`249*zRxpeq-Dih{xv(hEJxPr)%$jR=;_%Fu}E;Pp?}5d>{&6hT&S6l z6n(|*y9OY|qf=88Y82G?lxve+ov*G>uNYL2WY<>@p-gRda&m$qu9m|5)Yq4csBgII zuW+~Gk6bYeu{YqMw}kdZn3hB*g6sDcdy=@w`;tzy?CksJWunYA+5w`YTVGg1muHc_ z;T$zS2e=}09#05G_=J)Yg`j>;Fv>y3n${T5gCT??k_A3?W`3R_E<}lYE34RMFfieB z2NT)~Up@9{7`NzSuN#w*?f&v*ONX#a?#1PmTdUy}VX-eYM2C@qjB1N+)wl814be$QH|0(n^VzpxcM9juSxq@S3?^)^n`%iOC2kCyAQ<0fZ z11dHww2ReF@blX?4en_<+M2>5<;)L~9RkDUD-2&asm!p#0zq;HL3%T^({5=_;{!c( z?M2qB%syRfLyLqGL8!Fzlikd6Z-2!Js=$@_`Tj`%01mzD%EHqb)~na9QAT!;{{UDL zG=DhDnWihKXSiR)xK&ie&B(296@IAfyZ8ee`5k-5u}=b| z64|WG6^F;L2=EQ{*rYl+_DqjaM+Ji}r#*Q>ook?1(>_dS%UgGeZ&mU~lgwu8GLK(e zzJ&9|yod&D*v)?MTyWck`a6*c2_aBVzrkiFwk+VTKYLRab}pq71?+D~)|u}-#21yQr?mqoov)c!( z>IgKw1obYaZyseb&u#8J&emB}C>yvM)z#HhtDFXrqoXz`lNl4L!W@tE>7`%v-CeEG z19rkaJ#p;(0?_GCLq98ZFFwAitBVy-y#~Lf+sx)f*j3U_Lp3j;IfVhRFx$#J^lNZ` za$nY)9Xe8K--g+>N{4TOUB_m*{Hz_yX16=N7vkeP!9sPn$ZXb}+ql3c<#`X*R;@TN z3-??N_(5bBuywQ`OQPWp0M0Q6P7^)=4n#ewrmj~Cxr|45le+_xvrNVg7M5V}$35cW zO!V+kG-K1ta&yZ#7Iul*!12E|Kks~B%7qH-K~qrZ;j5=WS5&!-%g2tx-FBTHjcQ!} zIAlHSHh1UkIl_a%J(pga6nw|Wg3JdHgB5lV%Q06`G^Oc9*5>r^(hl`L>*baw#Nk{W zRNMjZE%*^nClngtS1=F+N0eM$iyJK`IT}H>L!mnac~*$vuT|7+C;>YH-M;hDJnro5 zWVpBsY>IF&P%#6D*a8_DStrcl&pAe^wDE#(#iJHGg-?HJa|X$Qy;#%Mrajzqk(6gu zza7wm!A-TC%@*5d>goLQkKC6L8EH7UU2M%JC>M7?R zA&>0dmHl{w4#(`F41fy&FFF`JK&;vCq9Y?Y=i>|(OXDj0n@0N2YR00O?s-InD&)3z}Na2UF*^b*6u!u*p5@j?F1Ztl;_E6s+lzMq-7YI$b?c8P~i zqC$N4ceD&v8m`#6`FSN&zg7?@_4^X}Z^)l*gtkiK)3fE~-1h84qN222m!Jcx4#Sk7 z-MxD&NmX7R7LMhg7r;{jv9|#jK~T8pLttYi7kIqAvsdE!)NaC^(e(l$6g_g}1cDVI z6$^8>5W#Wl^@7*rwZ3+p)R;SbCu#MVq~uP5B%V9R1D6H8&uoHI?c}l{uk3VJq0rW? zTji9XTy+3~LFA?+QFXlhik)A*;@vk+qxsyPds^BRib7%eH8h;9q>T$N1x{vWX2u`B z$mG)#(xBJdY5{i1B5{@T+`*j)lkDl?{N~93Vl(4#$}8!?gWG@@R$wK#%z6-o!>mcd zvdr4r+Mg)DV5KNHuP?dUZoRnL`cQsY_RVg{5AlIqxj2^O4Qn5Y;TsBea8+K;Hf}Zy z*;8j&!d2-~$Xhh3H-geg;e$Al?Zu1C%M0CZ;90k^(G^ilbW}7{RFZvsNaS*W9aK$2 zMJz3qG@wd+tLu}kE%zl;3xqRD0-mkvmqP&2h1~udzDj^KNPqf%(^6W}^OrC6-CkXB z87hhq8D**;8{^@VHCw(S%eH&BI!VRY_#WCWUr)0XA72na2XD{!BVuB@DLuGo6(RxC z3X>Uf+09_bD+%Qe-;~?nbzZ)mv>uNjRAhMgb^wRp3*A>@ySRmgZ*ONmrb3v>1@^Ov z((dA9oSC|+#OAk#pxpHQ`SVe=y9@UC|cA&KW8&AT*98eCdN?Vmzb*6bi02pqF#DwkJs$ z0P+lfwOz;ITX_!M_DR~GT(nK{{6RtGJr zNQB6)eGk4=TbbLO|Hbd{tzsHc5-Y1`(Sx;tb?NBR9zWK^wh34``8QMIIIKEYS_Ykf zYYq577f~&8Idp%0L1sqAz4xuMF6*6EA}(<7Fe@WtT7iid`*v0_6>B$_q1+QV!mw)| z5!!nbs*O@oQWAv^vMZ@qP5Et+`m2^9BrSAw)l0nr<;v8h=`k;XIBi$mIb*3GhSV1H794^Y<+zan!2atzTW6 z9Q}}bqWU?ktwd2vzpI!2plB;}j1Z}WU-7{*`g} zklM54gtITGy+*xaqIKt?{I{=Psn$#c?jovl{{>2cfvc?SwF#D;yLRm${T|EC)s4fB zVUp?SfiYNRe?VUwX;wo?ZF=8KBwL;%bhLUfIA=r+P^A zxRMxb=ae!~XE#^Z&GH*a^DrZ_Oh!sw#{Fr&fnE;5RRxOUC5aKv<%Z2zP;THKaygJI z;@fBE=Xv0G4Jz7eRay)W4T0U&_&=|!tApQp8iyOzjFQFJlZRg106%wB(}uo(S3Yd_iMA1!77uqfTXVmpq|p*RFMUV@ zcz!2yh>`+2#JUg^LGn){sljrb@zvvw(u6ve2|-dJ0d3$T zbP=&EMYHK0@n;aw^7W0P*%{y1oeueE=*N$f`}gm!E)k#)vHAKY#4a}h3WPDbG+>i3=+;xR;oiwJM*%iL+eo9s)t{ zFxCw+9aotw@^gv;t~zZ@`%U;SX`pH%2yhxV{P@WezKL;wTD>^7L+8U3UcJ597#J8x z^<&J7e+RBHO_h;wypZv*bZ2L07tBpj;-NUfLx0kD#1_>(w1;VW?n^R{w8xTW${O16 z-w2by8NB17ikn6X^Ll;ixcR2n0dOY!f6(u%c*riLCVCO|6O4FWW6kQj# zGxeZ_j^R~H{9E;|Wx@lSuO`<)WbzU0N#YMlN=in-;fE7RwRyAV(q40&G`4u|#x;x9 zOdY=G1qD28Y;0ME`P?p(U3!+4dXiqN63f4T-$A$-8NkR_ua4UH6h{S9(QqP*5}dA{ zpa9ov5B@?E;ss^4gJp|aldV#z{oQ4LRBNoS{3v=z2qCv4kV>plzsKA*6f_M1Z%4<2 zXa*6)Z{h(F?#0`;A$1WP8t`VdN^aB8#Up+C1Y#E9BV|Iv<6#OVegxTW{=6va&?!<%9%e!sXy9VH(6z&#=zsQcplse^|O?ZMp<9>2M{`EO`S zG32z4H0;Nc!8}KUKmXy|1#RukI%K4L zEhB3(;N<+&s_;b0SmF`Hn-B`zXk=_mJ$)S%3yUqDSw2M{RX1Ym@GMnndhtv0;~~#e z4D$`^GQyS37iTpWcM@g=1YjP7EcT+<4oc_iW$PDM)=)M!H}9YQW5XIRICVjD*+F0; z<+O*#u~SLM67Ju>uYyzoo{ig$8*PZl1em?b0&N?fPG~W-rKh9wJy6$}nVih|E=hV9 z?nztHo6;%~zF0jD6cFh%$L+iHOA(pgxoDIuwslpzJg6t2y%2D%{r>4hCV#s15l{Z#3AGCl7;lrk~Hhs`i zhrEO3xpV$_7I$6r@cFUNp{O9?!iG)!{ngKWn$u5mgsO-y){jS|D*1RQleOvXmPRO; z)Mc^9titkV6=8TFj>A80m)F9u_CSG9hjzr`1+!{hhJNe_rF8 z>8OreCJsO~Gkuhbu3b0E=KeyL7tibKi&no`nN$o*SAZwOk%YvesH#eZ`Ys{y&*)F% z)=lY+`3AXWTu^5r!sO-W`-Y$X!FJhgCR?oNb>sGb7b9+QiJlI~ZD!)jp~sIO51_0Odb3Ndbew z8~?wjlVUe4EG+ypP~$O+gheQ{qF9&1R1Dy%n9d0f4kn+gt*gVR3>QCthGBP!`{I7Y zSO}(>GyCr*Oy5$Ww@WB)K4-Uy`~Bdduzf75&Vy&|Ntw^C85xA;pGCi`K{vb#b?N!D z#uF(}>PA04p+N)lgXE11e*TR9pR3H-{k4ig=+1?ZADzzMWv-cfguN-)$JMK(u}OMY zmqq;b!-hN6n(_j}7+b(CS|N^p0;R~w*5;m=n23yvqhx#V)%<(qpWhcT-SZVB1#X$? zme1k;j#)YW)Raljzo)U|gYB^GB?-1v+8A=5HwBk2N$m;^7drjz`^Ps2v$(8kcd)VD zM&KX?b1&AWvGjZBLpWnq`-Mc?iyf?p@_-Z||LeqmKeY2GW{z0h;^M>Q@9m*`c;CpP zVrfWyYuE0Okt-i%HQ)EWFL46NQhy-D2v8Xau#kAgiNQBIVaujO z;qLBEn4l`%O#ftj4Gmc^dH{sB3DO885^~ah^NKK#((0-z5@8%8N2Gh%9YMOQE-nGO zf&sAqu(Sc#LqfxzuvicuwLE{G5&!{dNHWNkdd8KoWvc4yHz9QZ!#43i<(4fKk)S-IIE2%Cc2WtSsTn;&5wCtI=ImL759xMNXD^?$WfEc?FnFzuZjMv() z4)UhA=MILk93V$%#PG%WZKQT5b?_6+ky#?mVpx2Y3(5!{;0L%v{s94*CfC6~leFow&esTWpl_D=tf+u)MYKKnaVQq2iMuy!o9xM=p>af}9aI=IGj{Z*a z0hOu7@D^fJs@BLs-bRX!A-nz^P{V<%19Hw|pSGZpZz4f<-IS7&f}4Q7zX^yr7^ag& zwf`0ghp+w$%g{}%fZ_?S)o&xUsK?UP-I9(Jps~eHBjJ%EhYsxk>ZwKU&ksk_^3UqB z=;6c1G1U|a`Byg=YxW4R0lDpXWC*?Bq><#X>0`S%xwy!nDzHiu{>4j#PQ$@G*&vkE zBXCZ;YgZ`L5TZO${aXC>s}`L0!*gr3ea0`hBKY?joSisbvAz+oqzA!qu;`QsXNTdu zB2B&ra|D*<#1B3?9@_fM^k|Mr2|X0ddITVdhdBU8JAl0g$%7iCuaQj1 z)^_kkzk7EgVO)_B`~`9*eH$d0(+FjKN>gLzvd5#L@PT`kkr!@R6Bq?CNE#$mrJGU? zKjQprOpj|42{j5Ig!++{)KrOxh~R@D1R+OfltYkKkVVF9 zshJ{?-Cje`lFAK=gaVG27;FW>u%}Oj3vE6q3?E!RP=V-QWL#3xX+|Oa$B=Zl*a9^m z*CL$-BRvh_$sa)YoB*02re3}-u5;hFktGB8F~@84eFuGuiHT9FsHh+!U&!t0H!$zm zp-8g9H7cBOb(J9F#7y5U$lz@dIT|`UmnY|Kc5i zhI0nbIimsTe zKa^p^$w!&o5ZV2WM@)gSNi&tjEOX(<87H<31P*&+`v)8+Z@?OSOB%+lyB2(lhH>tAi645N#boRkp^giA3kW=gg$z9L`sSk5EIs)u2B&}KiTRn z5xVb~)1I-hSC2mRpt|Zn=U1xXvZ}lG`h14zc>$g5yQIXzs zojE>r;1}2YbUz0cs}}a-!1#D5gu@!l*FapVJ!9UNbUN`dGXrL(XDt-Ky|r6a_g|q& z!49A*f28|%S3IN&-L9gxE`w8_g`(fKVpi?lgQEu4W+6)2qu2Hc3lpn6bR(t!)V>NP z-KOok<>0py`vFyjlb?Sxj<9Y+H-@b!Hf^#(@D?U>hkLJQF6ZnuG(!H^Ux#q<;CqrP zEsd`W>qAtU8~jp~m|g7EIosDn`wz)nBkqgQy20dz{&v=uqv3#wx0~X}etca$@>F(| zS{RFrLZ+qi%5y~=6dDbe=kc$JFFWvDl`(yN?AS2^0Z>p-@c$ZGhA)Ewm7>eNeh6t+ zKS#zH!$c`l+qZ9*a2aEOzs-2#6ndAsew zm|_Tdr#}2j;CrQxDV_pmC9@WY;WGkYA*iP3K2lnZSt&-G4-85KDEvmZkBo^SMk|;u zmcvG?K_Se8>k`{^IAIqdu)rj(0j3Gn+eAVdW*e$Zp2xByBp&J?m=h%)HR@Tiqw5nB zIQ_(-50iEWKijW zluJw?t6A|?GWrk-sHtr>z5js5e3K&&c^`dj{MdO_N;xz2#Mv`U;`hVl>lvRo%-NpZ z**^2FN;*AWx#H4PaR0Mo@*d9}BlpOvDQ*cAcj6Z_KY8fjL1NIK4yJU`mHxB9@qpb( zxFl)!VQPO5Itl>Pi8R$%n2X~yX6BqlPnr=Z$v<05g{DC`9%qxlUGX-a7F$@?*;3K{PfU4Hy0J-O5Z3qDb^4W3t&{~l0h#`H9xmFROKR?mw zqxCVofm?U_0gn)&$zmKIV$uyfoA73dNvc2^4zrwJi1N@`T3Rl_V7Q70NX+D8$PU5^ zF{6#k4SeYJJG07#NOwUVse(5x@RT94 z2}zDWHzOiwO0Q4vS5NlF+}Wj`VmdUjynHRT-N;aE4&H3;`HKx>=g#cOO?Z`YAJ7S&{qS8dyRcgVyg23YCiWd>78PepgvN{Brq1w-aV%)%WPv zXcYBK*IFO zll0d$LRJducQ{Ut6+&gAqN2pB7``_@5Zd~7gCk$P47wxwg+CC4WqmlCO<$ShKS8LN zp!4CNEzQl%h6OfsM6N3)hQZLtC*}f25J-?J-r`TFE#Nd2nZASdMuB@s1!DO0E8+y#ZU;Q?QX%Z7TPx5l#8Dx0&>K~5qJYa2O9Yj%uJ+7hRB;>=rTYa0n87=8!A*( zRD73_un0m|IE!M4gLpGEG?aGt`*Fy(#n-3Vktum}%L;R{bp$bf z(F$I3>C5wtzP`RxO}cE~aJIgoh~bSEL$78%zkVtYRQu-X=jr-C^)5mp`ztDqg-<^@ zsNtk#OCQ7^KD#0u9ZMx-B~bB0lX^$D#Nh8}C7!ekuT`=_&*)P|j@jKEui7Zoj}5&E zIyWJSx-U!!jE}{XAma8F`IK77jgtKTs7i)(5KIgoCo_=BPBvBl;tQF{%hgmZ{O#(% z;L8}QYhXOMG4;periXqAkYNU%F}|LdTivkAuaEgG({wa(?5^lc5GvI`z`& zoqQsqqMEi9_ZJU<@d6!GRvklxe(y`^IvF^@95V%e{}jF0XQkr2q8~mA&+gWFrl+o& zug@WZ7um!`VC>Aq#H6v5^`D={;7(Sig;fsWB{L~{9Gz!#QJf6jVeZCsUdtKbl2M_eF1PU~j3G-gla7S?sjZYjz2{V+TA%b9LKi&62Oq{nA@0Uu;X|NC3_UhXd0 zRK2cPKJg(dow4nOlLgQF2cxS_4|D`&|JQfrvN@eL}xyGZgeJV9OFlU@C#t2pf9 zFNd#^>Hl?o2kKNcXTDs&K6LSi37z=)-`0lrB6j_+ZwOLVQB`JIF!vNZMsxLrZ$*62 zWM&OvEs;U)Sz20(aXBnw3y&lqEp0DOjOk}Rf$K@K-r)#VAj3n?9km(lO=Mh90KOD8 zB+VI|?B~y@{@(Y>o5_kLmbW`4Nj{06m9M@ya#KaOVRBi<>R3I4P*0hl-0AbaMlokA zyUq^W<4e=uTc_Fu?_o2NPEZq`LpGZNvY-PqsK4Itxyjhr_!FF9^)Wh#14t+m`UNa9 zswT>C%*^%pp5I1+Is4$cDp!A4CO`M+N||Zqeuc6@2@4Eu*xqQdLo^pfY74DrRi$_mFw?B(bep*|eMpcc~?1Vzbw+YFW zy)P@E0zi392I1WMviXtV56|E67}6pKcn|o#ha2zDx$711i5QoUjEeFF?!9Zo_V;0S zJ`rA;HXQVpHMnTV@{^`NV6ci|WeBAPuM8T5mKZPPA}DC^Y%AEU*p{4cr&C_8z>u5FNGnCp{?iTi5998J#6 zrP0vPSX~}@c)aU`3;r{8l*;1*4ZQRN7e!UBjd2Dfi}H|0=dwqQCQ^K;1r+&)Te0~t z)ys~jLp2^69PFdQQ;i0chhZ-XNlAIZ*s~yJq`vpLE_ekG`m{WBI6|Gn`yFnAaNl0A zM%E3H&Op2*q03&xq?i`y7wClm2$<^0aWn-wIgT+r3kwSp^haXs23b?#Hksifz&D7V zs|bWBB$ln6I&}*31bZGgnbXWJXsD`IV`7RLx&`rKBgl!clo3#63|8SULP&&2`VwHK zyV84I1~Lj1WbMbE+o-7gRCvhgh)gi<8~?*@Vr&k9C3UI4g2=5y9A#{ddiPG<19PKS z9UTdW0B@NgfDWdD1wYAGR1)u0f~IV7;ld_72ZYLy1XSfRhk{5(+H)eDVT_)Ig?}i& zurL614)01yot5~fwdqGft(Kr$BPVgvla=>e;%py2*+PvYu74SEiH#t@b%;+Ka2eB8 zI(|F_ZdFN;(VZ5j%IU03}CbAPhi3BJ9LOh#%D>RD};k$jIXJHtB&rs95cS->9 zKbu4>LutDSL4c0bN<Vps;dhA9rAsLg;<#DoF#P$I*(y7VoAMaHAJw1&5#xw)>yX#{thBHjXY zlaEPb;=Z645?6%_hz%bm*3@;zqVvla%h!Gz<0YMVlVm(^VGj!y1o5uK+%4;!Ey_e_ zk$%Sx2Y|=`XJBMZHjzj(7|bHz0Nm1Agl=cxB)-NGgK~M8nEp#K!Ol)WBvIT~%Kqh+D3Gvk08l}X#bXINV6)*b$>M7zQR?Gh5^1z$I4m;8c)V{1ni$~S)4_wA8%MUm_J zq@4}+F@+CoXau58sz5W40>#n-QS&c=8n=T~>iY9hwG;wQ%Ba2b*mm6ZyZ*jB2DBic z0GXJ-0UKu|%DJRvWRk$4Hd0V*wncU3A$@}W3-_3*4dncy3G#ypiqLhR7F8ULEwj$?gL&2 zrH;2J5b2w-)|8u(R|cwl&z(K{DDOR>a@xXHDk^T$JVJETwh$93d~gUxamA|ELoVho z&fiwJAty&7O9punihVl{+=qew=KGli>J|3?Q&?=z=6ifxS(%=MY)nP9)SWwbunWUs zzZ~x4)5F$hXdcE}qO_4UZ)j_~gLhtWKE@b;7JzlY2qp?*Vq??t_G66l>6t<#@Akl} zLVa5W;cagnxQ^n#GaRmQ|9J(XW`4=Mp zkhiJI6>y7mtgN>%;6Y4$9q)Sve};Hn!?9e`SK{jG>Iix0EiEr^`GJW_6ajsf8+;u4 zM^G1SE>@0z{V_O52Ee>q%2{Q!Pc=e8gB%mt@{Qy)_lB4^)7_>{mwm>e} zpX)rPk(4k_#IT@+?~g`blt_R=osM=gD}*nP7x3g4=Dq^B{4LTkFRD zPC{CoKc6nHfrG@b&<$mfkS4J4e$H4n7}kndH?a_$3f_8h=>WE$}(-0 z86KIOyp4iSE&UnqgxiA4b9Hy8mX5(3Ef~#)fFxV+Xep_ZBNp%9y?cZB%}>Ap(0y)r z#*iKr5Q$m&3aTmLJQ^9ffb%f`!mKkYtasm2?+K(@l;7 z%A|jJ_}m7gG>I?IJ$k?F6nyz0vv~6U#A2Tb07YyYhV>jqTe%1o0a__RKJhj@xN-(a z(QL-PkZ_&hL}=mG;V>E!Jlg-K==z340&9qwEiEn{K%SDIkuVyu8ve^pvqV-m(eMoW zB}{6C{9nwy2UO4f|3Ca`XemjHrXt$I3~4D^v`M99M;mPwB?+NYijsCTG|^s?QK5v= z&_ZcyY2A+(uFLg1=YRhH`@YY)&$&D2d(QPmpYeXb-p|+b`B<-y9;I$0s#ERIve_xF zbbZ8236PmDD-=3{N1tpYxSTV#B!Ez!6NQxF21qyD%l!H$%>x=Soajj_;l3Fm= zOFPFieV*HW3!J(8_Kk2o7`1j9-6YN2bIU2%*suQ`mpF!B-48t~huj6M*0Of}*Q zip6K=OsyGi;jd)}UV%i#qMJhTAsP}arMq}zP;=D~lwZ4ZM~XIOlrHHa1By=#wufPm zZ(`q%w&f2JSvCs`CU#-wRAxp-KkQLIC%#K`OVpL5#eaMBw5keJsP6GP8A^)reJ-%G zx=XKD@OsZbc}8;ubdz-M68Bmf*i6s;YKl3lW%}B^>-$Yh*XjJI#_*xLm_AAI=*O+T z8D~4EFYy-L$TZtkeSjgi`cligdB6A%~}k71$6W&KT_N6)s@1U@@lFW2WZe<*fwc4XhQWHWfs^5w!8*MTUUsM2Z!r z+Q@Z>nYNx8xQOVZa`^lJf{;}QC3ITDJD4WEyi4ar!*(9wamb(SB933IQPlLd6P8V4 zsa>W9nVYBGt*os_+D`w@ZjB90d8B6UacyRc!M*qZdzO=T@4upHgrPRIet+;nk0R6h zn(Qx?)4A1U1y0{03Krxh_0Pe>8os6u59j|LZVof;h+jD>AoPC*n^Obzu|HC}Cz3#J z?2j8C9v;^FQfqt-J4Qg87yvPO#|{c+DQQ~m?ytxsO@VgVAnOpgUuap&{b=YpJZ|%E zaaN1p+D#Nf5Iu_hh`Q0+(~cdL)PY_-{DVpMb5O-9LvBE5l%-W-m(b2h z&^`m0>c23yXOja`#K35Mf!RW(4A(t;#~Rmx_f+)s^knj2D)t@KZIx~=_3=5`lFtH& z_&i!`GQt9$73&rY@|L;?MJu40rR?nSr)JTifWEhUAOOkGW57gs7b0ou{tfv*tbcnm zXQ)Jb`dzl_;q}L<9`8TPS8=fCDW+b|!9A|f$AiL;10lXNxczkehE1Dxq8)tb(5^^2 zeQdBvk|s zBIYvXT0VF>igNOvJ4}-IzdtSg~ibPTHVO+|C3P#A6 z(Ma?7+U?+Cj3hR&0}=|OtPg)oOgx3`g0$K%YZi42oQJYc{FTh26#IVk@755wiuV(< ziBXynBkB)%fYFV?w3d1;dG$ozN}!C3D4o&~&PY_^1x~1B?XF+jcUL0CwH4Q-f$$9{ z{)mIVtv-6Ft1K_MVysQTI8LTlQCYn_z4|!|`WLcK0xu(VbZ>?^W2#YFV++l*@0`o9 z!Gm6lH%QdNFnM}Kqr7X$ZSZ&4@Hq)Tg=I~`>WF&}vJEE7moT7RBETp-r=+A}pLBxt z;5H=p$KPKEkxZrqkk@e-hX`ee`T8;%QD=I;UJDnwv% zke3In!`H0_INN|cAJU%yp9IXdcZJ$p|M>V-CDXmX{HbWgv0W}jq4C9YaS!Btz|Eo8 z@$@^d@&y6lXDEI(>g?#~i;TcTf^Ufg z1j^^K=4NiJ>U|71E>_#XpUOeBJs@dh0Uct9ro*fGV)?xB*mZfiAMmSqnAS`%1`?Aq z+BH?|BdKej&s<}bqv1F&jD2F6$G#rcn;5UmALX=Q10|3JMEc81(F;HPMmk3{h72T} z0J+D*=z6R`p})_vKEJbP2EiW?7(hX@qh}NYm^DNZCI$BVm_ZjogEY`yDu|9x{~2Z# z1t3iH7$=CuQTEJ$38@*#JcPDo=fepp8TV~GJUno^AJB-HjCrnfXc+~P+6!nEg~i30 zaE&X$y33+pB@n!A!CMqO_#jt83vK66@usaA_i7cWO%&kDsgW)M$&*$MTt~&7u}jEK zWko2XFLa5R>H=j-rc~qTKVy^@x~}p30lmm?JTG@C+r;e5U`^W~i=CBd&%Bbl$URm$ z_Vp4@qlM|5jS?z~GFQ{KUE-`Q5)Y`86HyBCUtx}2XtQpcLm>}>M4w`S0L7lcL9rHz zzbjyP_!Y7U0ZLWyf4y`9ljwUSC=rq{Xtis_+39PKfkds8l^vdwm|>fV3>gVT--12I zmIOLi3mE*AGj#RR9w1W%N)MIt!}IwAU4QhWzlMtWC#P8Bcep6&8p}3R`(;MQWq%Ef zEu7l%di4^0zSzFaOZRTzp^bD_9Fxf#cnjIU214aPJXeM*{`|?-0UFHGyd({H!1_11 zJTR8iKf1ni;c7;>wSCIxt3Mv-@u^!$#{L)nj?}mn&WhsV2ApL7zVqSN35AE&8KRG0 zpz%bIThz4+>nGbOlw5V;AEyfZ)RG61E>4Gy@EKn2S&yoL(wSOT1M$oACsF@^!q1<* zb9s1eZsGQ(-S=-*zE*W_vrJk6FjLVT>sVy#LOq<5Om}AAzkdC&D!-;p7wto8dJB3$ z1s1{I7_?;O%abcqWr_=5xXZZmF4vMTGB9s-zT5-p3N<%(M0UoXs(bliNQ2)A-I1*# zlXTX`yG4f6?5Ute)euWs@?SXnZNt0Lg8>WM-sSrS#n^u^PU8D-(0nAFy@LDs`yuQI zQ+#6jrNf4b-)>2I5v%?7e}Qv3mXlj}C1%^RFRYMcUVV7qE$9CcMjxP|^Qy|Sxy!$A z{z=IVGoQH*{c|~%OH|c8+oe%|gNpL)LuA>9k6D*b7%BcYzWxDa-3sen-8(wn zLoRG;Xl|0Odz04D|1NW#mFd#}@10D=UQlze6y9ln3=`}Zbsm+=pvngUQy_(I~m1w_P%Sfeo(bqM1-JgK+EJ;FDn5L$_1@1kYns4CmLz`3x`V1KVT>y z!9Hb;2aEw_0Y!+Y`$+2q*o(v;n3{g{Y@DA1KQ&WKV@M&ca$C! zif>(rZqD+(^H}A>-jv$_#3j}}uPuME@q~Y$3^eiEg^)3SZqdiEKDKAxrj9 zTk%?j`XW!(+Fe(}_Ir1a@u3H%g_u_zFAphgAUsiBu|D%`qv#x1uv|!9tSLPk$-zhI z6MvK6GqxfiAfOz9Cy;pqD~13@ltMyF$bEwuP7f9)$Ru@o^06H;BB}sd5vb(jQ6j|y zB`6(T}91x1;aq=<1?DYeh2f(OWE{mE zelbjQ^78Vc2`>Tz?ErKlG~OQ#IJ!i5eJ+B&y>H%3q9vf(ft;MgouI{h3KF;}!xbbG zfyW^f!|N}8u$F@0<2ZaX3|@W+hbIyV1<#2G4bButFf~L~j247sBw>3*8@MFg9S@rW zVU7UufDa&Q8RCl?UEZ_NbLE;8lpnOXZ#FQ!*;9$R~Qfpv`fa1i=YurF@L-cU)?{K=ucW&cyTpk zH*8RYcy#H_u#1RykvCAj z`w87LX{JHgLKhX6HGqCZ-ldlr?G}yQ_g2f-6N8_Eu}gHp=1;^FBcEK`I{~_b&^y1> z3oJ?%*aQj%3Kr6x;lGoXK9l}7-H1tzU{^Rn#3S>}FT2l7>zuRV-lM%XGh03}kDU5S#q^`y@-mnMHm4|6H@26YvAdaN@xo z8j?v!KTwFASBzyG-(;clp;iW3q52VXfF{CEarrq%kHFJPj>-7=(bU=#z#*-2{4O1x z%X%52c7g%|NE-!Tq52U$7Xt-T(Sf)aKzKqFl90UkIAAN1Ye=&BXQ6E$R-!=Sg?TxB zf6fJ5PEe56$T^NTk}d_a%}$Bh;*#`co%`ICD_6vjeWL-L!YZKuQ0&B9<65yolX48Y zZKi(2tU&Ry&Dqs;pJ^sMX(;4r4O^G7!BBjWq6bch^G!3NHcnjyx1;^_5df2msHmS1 zv#O!JXHOESlN;y;Ea)*f9XfJkCz@`kCNhm!-3$A+=Og=i1IiryMtkMo1-QeS0Z$(i z0rIZ_E#H8(98YZi-Qf-nfc8m9GiXoZG1S_<+t^h4w*VNqix{^KW7vex%VsJ#w9-o=Crv9*dyYd-BR=w}H0>41R zR^gmiGwHeb-u|cRo?Z?keiLa^8;G`gGVP)l1FLW$ypY94RJNqdycUGmt%G$m6 zJiV=_c5To1jp3(+%|&MhW`@E?6&FT!Sx>ht3j+#$~PO^V>NFJss20y+)9;@*!jYwR+g zV=E#x3=^Js!iyo%33M!!XZ(t%Asa(MHSBTJHa>71$h;?!+d(;3iT)@q(Riu{8$#I7oq}dBvKu}NQ{2^Qt#;?#S;)&a}b0@aL5j=FDPRn3m zLBLB!ZpQ5Q7`UZ?Q@~bagT4=qb)%c*nSu9yxEcW#!HB>irUk!Bp1j;#*`sxQ=mSWD z#^a#Ni5`aXn7xBLfplg_DE;QX9J4M7@&d3z@x`BZqSO&CS9RuARFn^b=B4CVwqt0 zff`GT0Onm6s+DrB!b*dS)S z;kc(U_zNOCOQh_=MY9G*m|G8cKt(MDy_Da07#v!HpcrLg#3jP%_;v33F+%>+sV)0f zy81w;$p!-&{EdlQ&W(CZk0w?)-3O~vF^bG2iI<%yek>%#xL=qn7)KSY%#9XbbjJlC ziovs0f?8?E3*o5ZL>ZcbB^t+>tVIwGL@bpsDj{Tvq3u%jFY^uJd5Czk@g^#c1Q|!L zB^`w9SpB2D#i(zXf^y zYi2PEzGQ58;MSIK* zLVO5y7%X@1SSuvdS)4BuS_=q{2*LB*Nu>6$Ii9sn(nC9qc58M%CPdW1gYEY%64SA& zkx?f@{8d<0+&mrE3*o&W91~H5UwAVL9Fc^biAU7Fy0eWNdlXUcqp#Wj;HM=f4N6CM zx6v`8>ph6AjRZH5mYx7m!q-!b2766zk7=G=VrF6z=H~rtBvhmq;ua8IzirYpJUV@a zxp;J#)0O9$&g_vMAYD{<>~I1XueNJf7g`e-OsNsPPQW){W+p;D7&S?7{TXi%>f%0k zmOvzmJLm|&1nm(*AW1+0ny9?zBRG~H20K`{nw-#>LL4zg9^yPpR zcai$k!uRs?1Xx7`SAYdc95Cd0x@v*i z0cuqMd7P#&%CH8+m@m>F|JacGear=@Cj`zHH9E&hPYQfUe(P51XWw_PUKT_u>6LHC zprpsFTV7MME0WLDST}k0AI@Xwl<3fvn-Q~1E{%5W$6*p=D#m+&yL_EeH!WpRWFn-T zw3fBI_r7*p+!f4c!~t^QM(T~p+lG0H6xG5P_3==ca=o9~QwJQBluBTJozR_Pm*AHF zP-Kpr7FckI=0;hPOF}Y~RTEgb1hKZ|^^@k6M0mmTMl`AocD)fWtAiOhcKmo8X3~J( z$&nF1cvmny%OVz@1nh!oHT=SFqGyH!#Hiew-z^fWn_gXNco>iF#^?3FmP_09T83@X z3LG($$4nMo`0UKCIQ=CodP@?}`yiZv(}9I|tR-KH=#cSU2_LMCUFqS8h*3T*iY%lE zzuKC(w1`?fC$Et{2eFb6!mHrtuYeZ(HS&nU3wlZ*T_SRDrTHlu$ajeD2&0!Q!sL-z zn4b`Rw~=F0fzeq&1Ja?}exxHft8?%aU=zYE!CQQP_;B<4_scN}JjKSDg(n&oOb*ZG z?K3y%U7!O>aV#h#+Xr@9vZa4(gRL$hDai)klc^s@&bT)op5xyZJv6xtP*$9oKJ>-5 z%dBn-Taq@|&CC+y90*KebMmaTJz+tvZL{11MHq=6#{Q{)e0!Ql9E2u93JQD>mI3{h6=31&?Qh~=Zm;|ap-2k$F`-z{|YM8>z6OH%JcuufA6 ztPxAl9Bd;|m*A$LN#uSMi_b=|6bKR;U@s$GL?`i?5h@1giJ|+hXGH)n0zqS58KM>? zub)}+lZm7mW1tG~HBdprG>kl+4(#$|lz<*8?G(S0i^~r5gTvSiByugW6M_Gd;9I<9 z$BrH0-_382eFfEmtcHDC5+Z*S$^&jcTol|xw~gzluVVB%@Bkiq!_{Foo!LiEG^VgY zPu=nTd-}DrU~!134+;%`92SJfu3{9YMnTZZ@njz241$ppM=LHtfMq&pu}e6l!wRhc z9NBj&8p6?p3Cqc`5y$MB@)$T*Nr|^@5=VmMzIY3U!|7R%6GvaPocOi_xoQ9RToR2F zZW4BcbN+wtkvRQ-^hk^#k5_16rcTZ{EWiE)At4ZtNwH+ z@T124rKI%iK_jUOfL0{FLp*=0Ksdj~iD@P#CLL%~iBSnx$`Lvu>i)PO!dSttgJ+be z4T1L10~AVR933-*yOWe36ap(KRnfma1*<{89FRiFj~IOs^MMr-3>6mP*$g-QismWiB;N#yAZn z;Npp)_RLN7C?;epr#yDiQdd{E#$e?J003@}t7yeirJ^k%P(U)fDp^3+G*&~-6;a^~ zH4U-!!YK=REgRW6v5}o@D_Q|X*;DMgoYV!Nl}L~tNl-=8Rt~)Y%8c~rG|nQrL=F$= z6a59#OGE+I3|J^^yw7atJ+730KVfcR`tS!zE#)HR`nv_CM%qWY!H=$tCR7cLO|sdd zWK!KFUVXHDsz_Tlk3S86O$SU$V~-1(@LlwNxU11%uh(T8|{ZMH8jTUpeJ!WkmBK*Ct}-ujiMi% z8~e?Eqh(k=c3K?!aQFpRzz|aLz%Y5~Ph>yzbDGEy)@mt)0+-Cj2qM=+E*{#a zU&qpbIbw7-;t$2>9SVBaIK%vwJqGxW5)}NKKu<+QMKuo3VcOL}rCs|cO$W(nIQFH~ zDB}bwZ+BN0iNhg*M2Nq^QKMhV{AdF5wN+3+)F@wonm57@zHZZ|H5f#gkT^sXyxrlJ z{ph(lhz7B?zCHxU&uAj03%i%7yu3-z1hhe#*hN<1av%rH2)q7IyrbBvRqCnlO zI^dTPQm=uB-mq})94P)}6bRpNQ z%B*8#x95%x1?!;oY^|xby0sIn3%rsf#rA zDAjoc$idxQ2B;yDJ|U%hdpgs%(?YQ_J3}9zZgl58L#w-1S9hWTNeL4s#e#$8F*wnW4{D4m zAZ!5}PS+HyZd+^e0gV8sZ>x3vgW1mCX*RIx_;{R~XjdL?e>^ty^Ih6SPRk9xB~jCf zObR7FnY345KAS7GvAa7{ue^v}`GLpk@?>Qqd<>g$SPAaD3h=caAeV6V&vig}Iv*pI zjffKf++)Ibh|aJ*7J@jWV;21{^W=Jr0$lw+S=)rEi-(ar)o+@kHIna{*luvl#xJ` zaT?Kuz7Kt8F2A^sIzn=BF)td+>o)OQ3ai3G;5Q^P&=rAhG{qG@A7{tcMa^ z4P#EF9lw;Dpq7wYn>jJ{Jlp!`r0ug({?hceMf<-BuN31~_R!oreO_GhtiE~AR$16x z1E6wT3AZZl_CAKF#{2tNaSp%*_zPeZAZHbJ;^`ymYLX>_BT^hA~lr*2O1`awO${rPRa{-sK3hW0c# zu!|Y|0NLkCetwf+UM{O_?GCvKM@&mg%Rz)|5w+aM+~en0tUn@3oAmtB6KUts7^6j} z8NH`RMSNMeZAg~9Kvhd~6}L|$dkzbm*aa5#-a_&L2nPf{#75P-NAA;%Jt}yAw8%kc zdg_vzFo2lh;d9RCLJs~DoCM-~eu|B4!lF=2Vd$&<;%KeC4evVay_jku7#;~oDl}!h zr*{XUD-KTK$Epk7wzGLu-FE^Y0(h5UB*{!VS_bM8$&`3rzLoNQi(oz0nw5 zTs$rAaEwvn-LgRyjwop{k1j{a)i zqwVH|yoe^Y*smKpQ z+vJ`p1_UscRaYlIxc^&BinJ}756+u*PL+D!(aPEo#^s>ix3_ery_U71fHom5dCL0n zC9a3l&w5z2#&JpsX>Ow~7s^7K{>Xi74K@jA@q`lKP@$n*<=b-4XpN+#j*zbPG{Rla zV@M5{{g%NO#5;fu(c%1{X}X`ks(cuIFeFfBHmm zRnHn(5y6tlKSl`XrV2maHG-IsXrvN%w<4%8@gVVH3IQy#6X6(b5{!0Bhtx!O1j>BT=>7p&f#~-@6Ee z7c@6GLWLf{-XaL*NwPOK8)AHhZ|D-yx`l0AzrGTmkFaSQPQG6X11BBfy%xjOSOKyo ztT)&c**Jg+I);WqP)Uh|i|l-kz;iAm4+68YR59J(STR56L;iOb$Uy0w;}@3@8Mo~I zN9D;_zrwYYuIG<;Oy!=|PD$$7_n1)!=}^=ZBFF;47w!(lEcwvHz5{??6-g!t!M%uy zfSgkbuQHh#0YA{dMU^;;XMrp>GS;BzBeAC(k>v0T1GLRFn1{ED*dGW!K@Nb0VWkX5 zj$-tFiUwKx)WYAZykVEa_8sOnTPb&M1aU~6*z2G2h&?>Hm};G&@1;Cv8pBpIEpq=k z`S>ow*mh+UiqeB%cd9?`0TWH8-78~gMc}y3HEH=R0mTwC`N}vy<2UTn{4?D8-oBRt zd$xQ~{%N*y72{HEsvx1e_ggyy#3VNV$Eql|`;+H&|FKC+DoWK>(=V?Z=$?G6j=VIZ ze#(BS_?VP3?Zqop+jEv>*iHY};OKoOf0jrRo<9~o`SM3_$ueQ9p`|N3sq{W^r!>=F zF>>#X&XTAW+SoKM@x} z@T@y94HW>u*Lp!cdwR}f-w(Th`_tLg6@Lwxr;y~Dd}}HB8|d}z&zFCopz-SE)1iGa zO!YLac6ZmllOFXrC9uGvP|A2{4b5aF)pi$UThntJr4}B+XvmGm%pV2pAW&Fx4t-&v z=l*>`oaAI_(v6dL0gXHY_jI*DYrlpdfcT=ol~u<+xc*yeP+t|J9diBlbrwp>VLRH* zGfdT{A#LGThGOJhX_+M{K~9=RBRL4{Cjl}DFpM^F9}fi zvQXgKcHi$Ud^6x9mL6o5lcRTYPc#pnKN}l9ul8kL5zk8kg?!&&?OWS_<-rgdbwfSbR49obbrpneeq&3dQTe4EE?C}d0*=2 zXq`S+G3K?HmY&^?m_~q^Bwb>lr9jvEvCf)GSVSnmrcF#uRgph}X9vUCP6*x}P7vQB zL85+)kLx2Mn~)U9_7orM#UVpq-z1V_aKLU6rx=_!tx#wQSpw9;;v(ivuubLoqySL> zhG0Vud$W~W(!>|hW98WBi6-&%>1ve2tN|6= zfJZOS9RtD@j4lWBEX6@K<_~@xi}$)+EuPgg$=rbZhI?(m%CsSgJbD6u8#+Pf-2Pp`^2}x4Ea# z8Rs<+x(N+(W2U7bC^}_yCk@4(u89f~`!XN~CPIkw*i(_zAc(;{65}ZzRau)V28bw$ zBoiQEB__u(4@*#)beOWM!OTp&{xHlXhy$bB9koFP zrJCr!ulazvruZUZYSs4b{1}Q3^6xHSmnP|Z1T(B^(U?CQaS7#fP2@XE4c z@zqwV3x z=idMDFu_uhX&y@NTi4ElviOzfrEeMtsF)mJPEPI;qHI8SNMHed zwCV>9(?GBs3&G^~f*SD`@dqP|X>}g_hsKdQL=!E1s^I+)<*N^#(_m{x7n^1Y^~$Umx21d(5_+^}2fF zP4LJH`Ssz)N>=@vu^;ED1?g&)WDBm?{dsA1`Sv8mgG!b&D|j;pIcrlK4(c_{QQZ-4SNpR?Uw1I0P}+~yPeBbM)VH3@1IDaCY&Fh^y#!JeK) zlrLc7K-Gu60VsAq7dctVAiS9B;Q?bMvsMwd`A9)AsBA#)5uqLC7dCV>K+NMH3gc*Wvl(5 zB{J|Rf&Z8|LWtpp0{fn&NzCoru&FbXW)S6g3MVmgOf04l(jCyXRvdpf&epgJK@oBB zp!GIbzUd2S6cSGY3^(OKVgARh_um5K3~(XF9yGy5$7cRUp&8jW|IZw>wFwBJr*Ydn z#Pe&sJUDP?fN#Z_xiO7DLc!`UQm5^>uSLzQ*xph_wRPgv|jO;LR0^#Ioo(AxZU>kj&v-A{FR{A0o#lN2NpZ*!queb+7$w|JPh=TmIO8c-E^@N~b_!>%2W3-GT3 zBsTHQnwp#Q;oe-p@e=GHGMJ`uo`wGcDfife3FAu+z(exr5+)X zbtrt?0|>C+Ov2SCV_(b9ySma{cPJ$UcXO!E%&2>2cg=-A^JQHpZ~d^a*yK#=m0JQu zcO)-Oe%LZCa#xDVdEK?%Zq`+)ax34-6!972thQJ=ekuMhU&9*s8%B$7staGwvJc8v z6D!QPnNg8{I#+ARsPWxieyV*D8RE+KG~1ADf}x&*KsA%XGg814sloH0hdc<7ooDyW z1Lgn#aQ)xoP_K^)(~~Np-lL5 zk9OSe3|}G^|BH1}hD87n3v!Mk!1D9RKO=?^NKDVci3d$srtI+|HQ{pXe@?(*Nir2M z{pCne0dE$7ZI6WDRg6T*9lX#GEm?Z*oYSJ$3Ez8KSXo$%c^gkO-1Vz)ahI*DsJLe` z{(F>opwQwWO=Meh;!{c!gWGh{MhCj4Tz?J>zC7trM*SxEhfRbq!zBYXWfl|t_+_ET z?l~;eXTS1V?Vz6ILjw-=gOU%eX*d#EmM&&`(~^BUzpn13`@-_b!-uUGudmFWRcx;A zajCoAciU6p+M~^rE=nd9s1Sbn2QHAk( zdUCP}>(*wNZ4_sRcho72=?qNl91wL{ZeI`$mQmKag7(2)iG@j@g&Ob4Pc@!H#Wh>{ zZtuXXx*x3=3|J)p;EKmiG_rEc=V4-jhA{cOqKv4h6)O71v?M5`cK|efw9FR0?E}cj zwha}9ZX0~mb$6P6Z?KA}yT@~}WPgbp1OmcVTvI&R_5yD8ukRnOTB>D(7kr@K-mfv> za~3-)&VBGG?mhWFi|Vfai^WKURa1+~885uFKTqrj_v1a$x17AqZ1VV?gg>)3<~UcY zquxByp+j4EzfR2g*cGqE!lN!+!HYs`jSezes;e(TFYOpV2-%k?0+p(O^;PuUPfNpW z`c$QP;WKyX-J+t~rL#RUd2N{wwd>BdwJ@m8 zKxfD9_7_{o^z^WqGMNSAidz-5Jf53q8JYH?_qJ#$Z3ax*;V2KAlvU$rh{4#{ws38q z`Imx6oMN(A4dWQjUf&GfHtafeeOTqbJ`uT~1QX-8c zwk~8zA{bTzX%uomC?U2U$g1Bm-5q)y3hbexu_c%s6u*q&6=Hnm-Ql*bB`*Sn_ARv3 zXg`@YD-MBg!nW6o&$o?}lQY+$JqE4w*mH?Yq}3MkFjf{0J_%-3obny3+ZMvUN1iIS zz6Oz#7n}Q8l-duilS7wF!VsVYdHgNyJ~47FTC~V(?A7L?(dQCyx23&Pw7O1gct>7F zmP0$Qi{<;c`-V2@QepG4f?djMx+el(HvQ>sbh(ag?tPt(d_`V+$ujJAGdB3zH!hR5 zknLf!z;`=;XYNw5kCW}~;ni08z9IrA83SWioi-r128Jsg6%{J1KH1HmPO*a{Hg~G# zyzTKV6?V6x^2vUzd~7SdNzwM7(KpxhI6HL)a64DBY;HiU;%Sc3Srg+qxP^{Y;3a}C zUd1lV9%Zg1LH4(CR4BUQJpJ~sdl-GSzh zK){o2o-&HDArK}(vR{@MVPV?mBFVLQ-+|MdSj-gYZ8o0Kr-jZ8$3!k_b?ds)>ij}V zo^MIU<6fTmkq{sMXin4gy=@--QfFsc+DE%Xbq$P-tI54X9PKlQvtzvA88}MX@lX+4 zBdm@GaakXkOM@M(gI=%Q@!n#qGG8jMi4R9xQD2IND;C4i6^Pyvvb|^6)iA>zZ!B&8 z`LjpHM-xQGW2n)mn@i`_5!L}A2cmU?X9f#1pXOn`TA>Y97=5-%|BB|Xdtx#+`Ykve2r z?{^%dcymEX=yi?{!7~~>xHN3x_)DV!Xc|=&8|jIn#C%p9Td%5E4*9P!((efn}-B$l|%Xq&YlJ$ zShDaMQJMihZP}j96BCvu>DN?mBsdw#Ve4UU-+J}4 zLFl}#`RWVQew4?LAD^~>M1}MGv8Wg&+{RhlNrjKrd@e9II5yNpLJD_?4URx6~4GF^GQRV6x9NSe`0AB_@jh~tT<+TwvR5$HO-NqB>NG`5NB!I~Lt#1OY*Yg(^~tS- z;1@r~Cwiq@bgbGff5yI33_o)`<3fA)OhKHLWs|Wg-6yKhHy#w)~A*hufE&%6I#sC~#Zz`eFl$ND^5__(`~RYmSs$^Vpki zp9feC%QU<>7diaH)k(s&W7yN1)%f836D>&)>P8PT2WXx0WsRj?7rqZ;pFd%Umo&cU1UG=^f!c+moW?0xMPa z?_UA@EgAB86Um$L{5LLM4@bC-PW-S5|H`!IzBy1rQbg1vb- z+T^=zkFm)N)4-g7!M0;7EW^l0JACwL4#;0EZSDO8$6eUl`o7?;-+XVe`twwI$@qnY?~BRO&v$WKkt(Fl8{kX@6uD$Yg4)|+LSO*XMVpQwmBTF%9q zPJiyziVw<(bjCVCzn9iCIXh45Ey2xw)K(j#BO8?HqZ9qaLeeFyW}+J251t$3RyKF` zUlkr}XruMd=Q*CE^hwRJy|?n{DV=ybTh&U3(XW+FU3Kwg;R4^}MD!D@!hXN@MZO&I zIVuZ_;*DMk4}F$6ozbqezfsCo;wiCVlUZEz>WkW!%&$3XX#M>x{jX$P4fOUdQZD>K zyXScG#XA=xl;{Q|9&hc`(byu?IAwlsPmtKJ=MIgon>4f2KDhbHNc z+8m+@-+>+PpKq!8tDp$H-X6>x5^ueEc*d%Sy{Trq4Ac0{b(^zxg>Q&LIr#nK{8CJG zd<0ME@5N`r8>#oPghy+_|7p;c2K$R?i8e1ia%q0o3sj60?tOM~Zv<)YT1Fpwl6Oo^ zHFaFfaoQ|vwEfyU5hJO?dAm00ZIS%XXJ5w5qflz2J>{lax_9+^-iNbw`=}>qzieNm zAf*~479=Ls!~Xl4TtE2y-q2)L`dvO>=7RIq2iD7NJi=W(s+#iS+$Y8*-0wdN^QJUv z{=I@p3cp%2oguE()h#?{&RZ=_d%i<6HH|6DbEWi_Y!$8IV`8!Y+{^X(-Qf?r1B)r| zrPi$Z@uJ-JxYeguW}8(yk9;{Lo#0BHybX&`4mI}QUpOK)!+5aWO|OsHH2LY~0lGxb z&S3YYDgG>iqOEL+y5+0DY&{?G8!al)>vsf+ZLw`U_|MIcDw4_P8huq~ zLs8fx(Kl)D_u*vP0t++E&Fu!>tuDe&HEYeV|=kp2kL_NzjYZpw`icG-=6Y7Ow6rr%i1;d+qp#5*Zx^d zyG<=LMTWeJ*(vO@$CV`p-q9s(xcTaVfGa-@W2#l&F6n1GWT~l~1-&(Pyl+{Fh@Sa;oPc>m4cBGkrIs%OS&$JG1DF1n({ zTvgp4YN-ERX$iGVv@9`Z4@JBb;w&YW@v#i({FolY)^YTF@}l=$t2jF2=88;hm5g^A z|L4m-U@iIdD&rm#C7b;Y^Y^IN!_mxlmQ)U%%iYt^zwDy+RmZTr8TY)#vuROjN&ha6 z4f}HxKXC|zeE)Q#>e0?00cqj^?~=<>=vT%set+Ql;^C!BjHU8o(~?xNb^ck~YmOhp z0gIcusjLd3R(XY2+(*Jj%L?wKnH_~SL zenaJcH+d&f>R4`Bze<*?RTryWMNb=Re@{*m`+E%-#mA=Idvd^0{n|(*$4gaYDK1aT zvTt>KVp+m1&tA?qXQXkU$zUecR_otK>G(N|xs%&O!p~Bk94Y+Tp5({!qPcW2FL(J` z*0RC{`ZJs5LxlzU>YEBoZA-g~|NDyuWMO9S?(7=zHtTdY;XMvpu1KCYSza?d??;axI~g1p!>$^XDth7= zu0xGWQ`o;%>P><|;g95DH>3KK2FnIm3d^&1rPbW2qjKG3!q0YTu%SrMOQWe*D)Qu} zEWQTT7q-=ha`;F0;ZNM!7gl~7*{J?H)A~=HVM|Tk-p15sR=`WiVO%p_Uw-7c@`2r* zdrc-ykDXk%W5}{9!IY_N{pXVkt4gTv+D;01wG?!=XD^z3c0KY=U1XKx1O4iM%EX2h zV|!1Q3vQsuCms;zUv~f>>{3lRga1=I-M}ALt#mIsh6Qr}m}c+@mZKV>p%WE5z5L$| zLnE|GJ0HgC`}Rohxsvp(NZxBk>!PhIB)@FsxzVCp z!xmR%bU93Vzt;4(jg+ZP4o{=PQ*Bc{^!~kme5Z4kq}=3rw=tAU^NPl|s^kjI=5rtD zM#QZ0)uhX|R*es;k0wSKOuf-weU&%3cz1N<_8F7E4z~T9@;24|jS?Te6|CRv0jTfwCXWqWT z%AYg6ZsnkyXEeXl?XdVsOBUmvj$z;RVxxWnJF{}blr_aDToKw$1}t~~^UE(b^*(fR z5KRreVb#~lo+G#M(Awgjkn)caor&6=Et|`}(wUu%y2^W2sc@(`WHhsxO;oz#v(rIM zd$qh~C!K#E+%2|8q)>=ro7YE&(R$9Ujo>cpt2Y*mm1X#B6R|waZSuhA(N{JvIQQvI zweK^QYGU|z@jcUjn3Hmwwe%B>Z}rK~xOgT~EE>iS&T^$aSM3^TYshv;sWD<+UcNFh zr_X=MIh6q^vozB*w!Z}}{_}K?8N+fqN^+iV;(;gk=R@ZdO{UdI2eVRIESpN;LhZ>l zBblzDj!##ZaBh{ZB_Sr~HAM0qvbE(cQ#yJ->y$wISKKYavuHUfd zrQg*`)hqP}t7Mrqw1mgg?#qf)9o)V9%7>h2_J4anjhDrR_T27?ijhp8?nRJ#b{6P(NeDdt8_&VcOl&|!=vFzO8!zQf)lA4-wo@ZR*6TSZIbju9e2PJ!R zqyi{aEmD{JTbAWKrfS;rLLgNDL(rg(>}IPK9c<5{HwG)8rShYiliUB&q*dBjo1g0M z2S_^N{RC5Aq1l@KF?1VP>gNQ`nk?mswriNV5oxTwh-vYMZDVZbL_{KS^0@a zr1TD)luS2GGyTQ@#OlxLx9JVvXqKA7MBT7$ZIFOeD9;(19G0zjXjXhGzeVfBa@&XY z#+uy~hcF%dST8SKvOKtjKj-M?eg7^Y4ISF^bS^jTmI&&8X};v43ICUkE8=c=pa3(tu34;G4^E|lALhg`zMzn73*d-K3yLyKRt!Kytks}9*Jza6t$ zD$ayQct8NV@1DDPva@HY1Jr4YEFQ$Dy-d4#iQ(@8&`~Hv3$5^E-{Kox{Nu5+-=p<% zSxkQ{%2{#S-tm#J;2|jtiF!tASyjQcM=C#S>E#*ur~Lc%@7bHFtOzb?r}1q%Y4h2a zPqNozul5Q$uPJV~$-R-AskB^n%q7~C6}Ycgx+1XS^+H>frICMi*XuoAYzhB9$!ecm zUTkLHpxU*5TjJYI8{@EyW+@YA1$m-hUUL$bSGdla&Gk4$LA8ffDXXu|Q|#>AX-R4B zUVpkO>A=5}hMKXuvG4T#{0kv#S)P?2KY^=oc>VpFwH6e{<8eJ+`dpF{@h6+UHTZqF z+}G-2EX8S!N#)Nv)E`Id&Jkuz)u6cD&+ubnqt9qYUY3*ov6D{9^}*%0mQSxdVSP0u zdHnVVj+fp`d@Ka&>h6%2_XZ7aw^@IGF|$Ocq3=!P z#_jYB930C;rJr^?TJ3SPc1krBfAhDBq+r(9?HqZ=zP|Q{e7}-=Q|MA=DXN6=)N3DD z9!xOoa?7$(UsV2F>s+-u?ThDGk3M~q;uQJ&5a{dcee3eE{LU<>`>mZPIx%MBx)KAE z@l)eY?P7ANH`g(rPj5C~d}Mu4Uu=j}GUt&9vnB)mzeQZFTKj&1!BdGFzOKL$7XG`0tE4Pq zU*w;2rmWk(`CI{mJ^9%nVK(}F%FSDvawa$kHJNPZCTVHDOfpmVSC%ijS|zG( zYWRCgD<5?}`Sq*G4V$snuKkVdh9Dsifqd|q>I{5fQ5=<)mc|oXaPq?qkUerpiUL9& zTx}C{M&qC2sqPZtQ1O1VKtBJoKtEB~4h* z-2a1p&4Ba4J$QU`)urU}5M+(x*y;d!YRE3HI+Ek`IcV(-F!D%I>FMbK%RYu2hA||D ztw*2%OdSMyC*Ut6w6B?V|0#a=67GqY;qC|gdgQFrkFsgK-sIM`Q}$e3Z(rYbhPs3= zv>qJ?+ouEz2aDJzFS4y>7VsbQXU|$KVsq*BXnEIuqh*HN2Trjebz~iIeP}7JfD6Nk zW=ecZcBMBv0H_0s7(!SHM33Ep!5Q+8UZ_~F9rOREhUq_0ebOO z;=$$I_MZ5F(`m-Q_zr(OkvW3@kz@VWc@B{SUVw$0bp0gvP0TRq8on{~ zy-Ui#lC3~skWBH9W!x5S*75WFe;}w0>ah+|-O>Ej6t~-?1nMu))*p4*Q7h*Al6{Bv z^n*ZY=O7MkEvEeiy2*z|YL=I1N!He)2*jnVPU49U+j_$j?)htA6yy-X%G0Mb-OXON zwO=x$aBwT08s3c7pE$*=j@#0!JiVapIS4!65+wQ#Ej&W@hTCvC{dA?^{8GhI-0c{T zZ(@W~Z1Y+09yqg&5@GJlh_4|HNbevRh0gxPOiWWvnP=5aO?hIyCsrV9nY!bhnL{GZ zUNm?8befz2SZ%L^cB3RWa`~I!;Q*H&A`RYE;?az z;w_JvDbG!{@tc1!D z1m@C2RxMwRcA%`E<$G1%nS*B@-88vP0te3{ z%C8`#Epd1fK;Jo|FxouiYt6M=@SbQd$T-?~rucWIfA&y(DPUapaq)YMbEa(j+LQXs z_Ek(Y#XbDQsim{!UdzN9J)^3d_mga#-;Z482{UsGp0O33R+cIUB;kb_PD zLqEnstk6uncTbptu=o>q`bOMr&6^wg`=bn7#si%}I5&FCLn)mEp-p#BPm_CyNlLxz z^@9=r;>w#GfL%x9|0Yeg%<}6FxxHqc9*|6P8WydCvk=JSmU; z?nLof{L3rz8Gml!F0)(p^*--TF8*f_+9=Bz&>=x(W-cmr_@TPuvsgwBPBzIA8(M0U z&>Okz?fT@Ax01bRy_{v9lbJ5$%dLGD?ji^>vqlP%!epyU-kQXfVU%c`R{~{V>*?kL z8x;%tGYUn;#5xiAkl{2Rbs~ex3Np_ruae!I@w7J=EtdS4d> z(%;s6y_jRzuI1`vWyml$*Ro*`xt4)A+ti)XouSVzYE?Jl&9JsBf4;DsuOMf?JT^yg zSP(EAxqH*SaV-U)VeQylI=%bK;~f{6tsxU4 zn<5Ov5l|dI{J;2m>!>WZ=6{$H1OX2nQc8EJG^i-uQUVH+ib!`03ew%BgpyLyDIF3b zs0gTlfOL2L<`VGvzW=lf zM&UIuYmpJN-atN#ii>N9o6%-opy4F2P-UESs+-!$ww~Ont=mG2@XdE4c~FNrI^1uD zvr&U<(fbFZy6;B6K7?(FpWg_T>p^-Y+RB zFfuS8BX5^#*OeZ)Op$=lT?WkYV3Z2K)2I!WO_}}V@pi+pB6w{$L%;*{89fy z2!0I8Oj`RxK#yVhCjrVIA2v&CHRk&Sd75R-;&ostkQOg*m5E672kA%E>$4-+qN>o|bm zH1tXc6iA{wzkHFO&FUW(*X<@l1Z?5w%RU9rjEzCUjTb0?k|Sm2h!p9VtA{~HlWH@m za+c`ZrpBWMh37==zGCAmcA??d_=BR^-@W?vg6!q|&;{w4-rQL8D~5^j$hv?noiUI^ z)HO9}HmW*}4PiSL;g?fBQnop_>{)jKrIdAWW7P1SS)|!8{m|qRnYRc#4#5}W zk`Wn7f!NAaF&>^Aom}jIdxmqchGL){q5JRr^F9hcf?lO1TK2OlbKrelkuCetE7V5T z#m|4@Rr?r}mt3`vEgxBw-%RU#-8Q32%ECgYq8H36PNOTAo12vwp1shuZKo{0%Ybkj zV1tcI9PeKAduTmu61UoPfqMmZEn(hJ7x<3t{K*J2BnUm(KrgSVY0!17vtRcQW8SiQ z@mUk<$!^*L8YkgPYTsYDD%?>mDv30(;r&p;r*ORnJ}5q3T|sQMg$PMo;V5F23PoRd z0=i?MR2l-s1|~E`fD)A-&bIb*u*3ZOCuHMDpk{ayuWY7MHq^*^W@As0W9)){y^6=^ zOo>WZ4$JrborZw<5v-`zz=!18W0o8=VaRuaj{@F=6HJ%ogIr`~|8RU2f)n{Qb7II_ zy*Qdq;hG33;iB1H=&Y^^K85$XN3o*yysZH9R(HL@JF_ee`sjDsvP%s|t+CyObj!1E zC5O~3$xM|u;_&4TTN!xI!Nf*l8fFHD@VGdw$#vj}mx7X9`^DMA0(L!+^aIX7&-?Uj!C^v8ZW@|NHB1~xT=y1CeavWB?;Yy*>q(|r`=0^s7N zI++NwFB6|X|6_TP^=w51%!Z+mbglrJQ|pkGp$n?F~f18O7Qh;IC2i&;++ z(@06Ck$+B*>znvR5CBzv_g{u@iF07Z1six9@c@Nn3ng`n3L_{4>;7ad-Lx zN`HpQ6`-RGd6g96x@UzaqX!gx(Cm}TlfsFPCS+l~toBDAB&_*^xHBtDld!-ouko;O=s{Xe^lGQEDOk(I-rISO~D#)`; z{Z&!;4ON00*%i_4zMWL1*GC3Wcx4ra+Mrm4n~bKQ{frtR#yJw-?l-E20-#3lT*Uo` zfIA~}|DkvZVdOPNpx|o_PH=c`-|uQTqTCL@jKR*%J8Thu|9e0oYmjU;VpeeNAIqYn z*r*o}!pU_v)2L*XzP70g%g+)##_J?*nqaqN-DbVTlbic(z0WK1c<{4D-L#IoZaHj` zVkRbN>tY9WCRIVIgmCtu8u|^loGDNlfiTBU8b*4WzS0*#!x&Hpypx7;L2>;2r`w}7 z7wTQHeLQzoC5wbU5Y%;LMTSYdG4IwyQBY#uy8cw1(*gITtl?aOS?WWO>h8FUr+V%^ zM)dliIzn6x5YAr$OjQDDR_Gm;)7?0L!kz!^RwECJV4Wo8BH!FFt?N|O>nS&sr^>$h zssFWhWjzN~ef{!-g~jg?f_v`yva-+|#cauJAB1A_ExNZm;D4CSJMbDf~DJ;P*} z7fLBb=}3Wjj?s;epZ%b!L8!H&!1Y3iZKE(v9Yzd3B**;ev3)qVnKB%C}C>hdc5vf$n5B70(fvoq#TAM~Gw32{`<(>@f8H4c$*HU_uojw&kUEd~a5w`B?Ki-E z`0XpQ64BazYqPpF}$6~2a8E3~GXcYSZ9o)Ti%SemK{u167*QJTCH&6@XIJFd8 z#cVM`x$g5MG)~w)Ty=)GAB?DgShE4OIeS8jsOYLK?*gR@BRD!VrDYr{7J@p@;2C;73ing2Pu<0r*E zZv*;~3&FR^48jZo+=;_qhbuNskKaCoQ;}UU$jit?a#!Dd_1y06F2{pf!BPFW=_D)o z8WG9KACagJx_(%&ET{&H&!649_~Q~!%2gkbpuAF6A{5hMf=D_Zt0IxJhK2xQy3(9d zQNs#(%&D&omh{-_ANXhJYQanKHT-fQ%EPPdD(V=fM?Ek`YT9}L_E+Kl|<=bx46 z$AkHc7%GvI7&!aReILY4B6>obpQLFcl9Jr6Hk~p%BGT@yO+L6j6@wZ+UBbWClE0_? zj>X4-WFsD>^XV(`o`FlOD%vuQy1|`>AI@X))<`P$O?@MbYEcKqd8P1NMws{nFw3wS zd2oRNS@|bD=s`l}`;-Ce>M$M{FdTC~Fb}ky^ z4+=%^iKAM5TFVM6qO($C?hhEZF!07iM4&z=RE3FKVEQ7S{BjP4qHx?gdu(9~c9zME z^V_^%$QaiaCCgEJgHVy(gF+%#CRo>cvf5S8=#`??OtPi!6bI>UFD}y!D z9Y*n=Zxo5^_9!vXq6>8H1YyxMrbcBzY|vI9ybJ0Ka}Jub^BJP9O7g2UMz^V*2u#|mOg6uT*zL{Dq z6Jce!uxVwUmfHAu@$S5U4U*6OAd|eML!5rW2FSKbuEM*eW|K6$XhGds@F21V;CdfJQ`#O41L+cTkgC0_L~EXDUU1CQnIk7lQ8GSz zYag|13SM$okRs%xs#jL>3U6Gmis@f3bshDfa@LVzy>(twlaPa}Ej!!WKxIg+=Q-OM zlP}cH4!~XEnHCPd(cWt2sbh?vbrh6I1^s4x-{wDY)DaxGxR7gTUKdvC)4I_dI#hF(q!O<$ z(q~UDcU8Y!kbtL=_1Y$=zJHoHTUj|5ldr40xKgrvQ^CfjvQJDJ)uPzup=gV6g5I0EQ|k(|D6SDjk%xPo(2D8{Z-$zx)9ox@IgnCHtS#;N(@Z&-4fCgIunKrR-Q&G_br`?_LRSC=!Ss4h_Rxy%<->;Qv1Xvlo( z{bwa35aoeV-4Faiq}VY|zlu=sa2oLK+BeK^swEeA5=l$o0$F!+89HVks8Z�KEnX zbfG}k$owZwX^$TxG-`Wf)!h`p?bPnWIWk1g2z~!%prp6p35+&^ENy4e%1a(};j^Hx zL<+PgU-AP?&u(^G?;e*djt6NMyvGbp2fV2iHt&aO#bN~r@%+(zsT?5Z?)UIN77pmr zJr@$E$tWQ#?Jj_{EFu1ygg{ac(ip^|bq@!f0x~_?7CLN(K;Vbc92om_pwpSwYS<&8 zg$p3R>6CXP#3feY&fiwQ`m1;Ly^><~&)Z&%J8~b4G5xVPHuFo&y}R%(t;|gQZ4dau zUewlmr&1Pu%{-{#H;bTo&D!RlPrhK8Qwr8v<(6Yzk{88p10UZ7>P*l+wtT`ltK^2+ zx6e60-Yj2Py@PI7bZ;)=N;U`fVu#i}!tXD*xz@xo@fy;e3Mfpz{ba;)(8nIX5b7ba z3YC|Tuz5fZdKn-kmnql>@x{#ox@Q)>ra+NK~R?Uwe zDHZ)vRw+8Suxel@C0d^T^(O)|Jv(p@{ON*~(dEy08C5x&WG78ikldIsBnVCWD zcT!}@lp2#`mc~8Bp@pcxxVq)a;2tZ1qmi|;>Gdl~>vBC!XTDUCsj0!w3RRtf7VQ6U z{c^M36X};`s%M<`4{d(07re+HQN{Ice{2FwT?b{h(*kN+&AKw3P+u>M#`t)WHrn%m z*9Lvh$VM302f^2>&dxSLKx$R@b^rF6tq!K4scg4_IkQH60?L_Wi~{9AISVPCmkmYX z29H_aeek9G4}c&3f6TlN8$&BuMk~)UbQGf3Z)QkKXIG+f+5XZSJ-9C}wvnCtFk$HP zfSV~O!N7!ZddC?mbi038Cl&K_;deQWMwDNfj;!`Q@1%VmFLtXWUZQ;a-dBmfQC*rZ z?-?3n7O~gLWsP&^0CoO|jZIEM0tA5n@2|l^yXiZlrOrK>^i2ZGf5>)6)(;i1fAJ+d z))vM)qU8rdrNiP+DM!2<=<5}ec-Yu^nHdaLD`62nQV+GL0}%!6q{P$Q*RcH_QuP3d z`R|{^G9YC1gf^+9Ws}M(_0C1AZfD}r-dd`7ZR5_iQTg9T2ZmBg7Py9w`1{a)?XOBh z>S*BNlC8tW%9`F6`Gy%vCFV4D1a(sd6HD=JI3;jdl7UjB6G|GP$zXmKcJ`ouH*$N5 z5Y|;O+L0Q@R6*UbDDjkgVoSUPdIS0M*MJK%BEEJUg4e|Q zu8firep=1$bQ0)XE^6u7oP+Fr8lvQMaXZZbWuAwwAd4pv z@9pk38gF_5f{U29sjh*?PymA1%^L)f;-urJjRPvJX|BR|vBG&zFgc?p=2Cyl`Rj8% zd-m1ORX*!|n4HD4rdwz~3;}$BF5s4iybI6poa!gW_F6~-a4uYsg#k0PUTdAqD0#W zL^SJQ(98#5n3BliRYY+1)#Cg608jw`?=$pxaD9g%)edXlekz+>7kYow=R?1Q{Q9fh zE9$IM*3wtGets!f*jv~CW;;Cf`{)TKlfSQl3&mghQhNd{lx5WvDGm+}BAo|(;Z0Fd zc#>r`>wyaoxiKi}H|D=`|L=8>X+CB9wSWHsO60^-7}im;B>N=w8M9yGwv}%y4)$Y$ zuqVqIVb6YE^h5Ply$6NwZ@vD^DYw+2X}L9 zLI5Kx>jWGWPNo1>WM~F2j7HFr5P@@nG)Ij{PJwWl(u%j0K28lk&Up`Gkl|T`3|j~v zWNt3|MXLQb$3Eo25l`Rnt&56Qxr;6Ym_x^DO~?0~Nx;X0&W2@b@uNJyCD zbVv@;X&nTyu+VH~6aZBBZ-l%2ZJTXay%^Jx ze%ck<_i#P3{~PH=4<>(ToN2#$h0^dfsB`LywpyT&xHxdN3yl(9UzJzQi{G2qKbG_K z6bEIP>F(MTm<1-wbJs-!Nx-&GZ2R_n@!6-;A+$rwQ`+Go5}(lXvZ_O#{Vpd6$vFt( z{&u+CFga4iKe!+&o@#K9m9zpF7wyH_8VVkOO~FV{BxLya;Q{0cNpc{x%_RP%Jwi*= z-!|MW(qb#Azz1CaJzo8h6++1T+1*GvjPpk9;!#iu2|J6kjW|U^4TYV%`=*jREV|{L z0z-4?ghF<(^J*OoAWRWc8rS@@o9Zbvbh`znk#<=E1na1r&^I&eM|e_nS1koFO5E;T zeFOYD)pEl-f9S{?KUMKd-Al;LVsH(}0sTj$`S3KTg9+05$O_s|HwV?%X|B8_&TF;0 zOv3@a&C1zl!1g&e_uEE+BE=${G?lc{Z$q7XMQEXzftKjsuEC_pi)4{@pJ=b}v z^y&DdJZ5Kp?|*^;;i>9QjLmQIQ+<8!az}pDLyUi8a?_G&U2i$ z?g8e8bBNpw@{nu21rO>dydu{P9l5%!=Pb-`Gz%Y4E8RXEg?QrG zFZ7SEM%m7TJMFGh7rg8Pq|xsj4$2!^HymjN-(F%U^Q+!dy`}=(kXIGt^w55^{*QxA ztt|RE!q>1KZ+bMQb9%*jnhv*#dVb_Yr33ZUZk_I+m0>5%N(>O7fn!V>MsU7pvwA*2 zVwWNLevP559PCveRxYI>HA2#|7H~N5O?P;5n!4^LpeLoK{p3&jUFpi;iKg*7kQInu z0k6HJWIi1qNzb;^-h~N7f;CRe`}JuLd- z261IDL_##Mvg=`eo80t*?7_?S$$&ejGN~&)1GOZ&=P~e5djEt=)!2U>XC}7r>6`H( zz&%_~Y&8h52z)hlcueQjt0zaw-6m#h?942#A>a8dJvH7Hz)A$%zX8btP$_m6pocb& zXbpi1F$rZ!h~rsO-4hC8tQUU^OiZM}EM~CUoqU~1e52-axa}=?&ptQN7#UmAZ26Y# z`mh&U!0~ywz??OQH8&s!5pse^mIQq6O>5$e=T(lbp!G^R*_Gm_bhu?sIwc>AdFx1JAqGL0Jh^ z?T3K@$lz?%Zl}?=*F%{ietckufH~hQtooi5)=7cS3x2gJ^wOX=QGrw_h-evtGguv^ zJv!@n3@lJVaDYscLD0fKZDn|&pUWMEbN^;d$I}O<7-%16*1rDNUeRFYPy~kr&{ic$5ca|PZ~U+_7@V!)I@?O;Jy^#x&xil{Ppjk zLALD`nt{Gu)1$MV{lEfn#DHr0Ni0f1d32mP-wiLY!oN+n;Prm|%T^%AF#!y4p45$1 zztnG`;g~wA=Z!h5P(DP>#nGxLP~3jm--Caa0J{3W-+n=tHOoq>-OAUESllBr)%;|K`Aa9nOnObvybI=wGqn$cW-w^!oWz5ZzL zNviu=r*NUi-t60ZU~vHY+|d91`nE{g?Gje*_j0UbQwHi`3)RdiZW6fhIM}oO%eFjU zBRb4cnrG;N{b8-2u?{X04rS5J1jREl>qHP&K!W`Qh+>sslU6kuD#F6WCH{x9@N0L9 z(~nMoe_;p&CL|?B@6Q09$JeJhMyP1QmQobe+{I$M?peI9SBtAholT)9m48(kf6~si z(-VzX7^=%s=VWwv`CUyMSW-YY`T&kTRiL6pw6(Ke3u&S7N3j617$Ar$s;bC%-om;q zNLpq<-vI=l!evT3>p|a z`JLYI0~{dMx^R>JDp2!M%5zp7lyf=!a`}qdeA#X5vbMzcSfpHpemMQOtfDS6opyA) zU)P6IvMwvrGd(dL4MN9w)lBLnG}{fp-v{x}^9G{X-uAMxyZfJuDv)6SqAsU$dJ9tB zLTIn-6+nC4==)$Uq!35r#aO#1|40CXujUyijYaRv7DEhQjBZJfcRUF)seCV>A~fBM zjl_KEYG~fDv*QeOr;n1Jp3r~i2lu=KPp|x2XNCWp+B!$(T@!H>VdF}{49V|s>O3gE z913jY9p2-yQED;dP3SQ``guGqe2B~DsU245@Z_Mz(MVqTZdvr8?nlDAn-9s`KH=Xy zIT@Lq2&7l}gY%^m)pWZn9JDZX+6pmrs&{aH^Rwx+)t_Hu56b*e2QNE>u`UT4Vn3g^t>*O4M80ezg1l~d)6lV|(fG<5h%VyipR~~Q9Ml=9wK<2@ugG+~l%diPaWb?OJnviw+(7NfDD2p1buD=z`p?aB zp%?!VL?2J!$ax^bQj~b_Amy$P7T2j=o0hB$n*CoF4`Cm2VSRr1&dNK*t2!hkLrQgJ zpgVqx@wE>l+;MIHBuJS7ZjU!L*$!`_>{TefaSLi`R@U|exdSwu%RFvRu-vbT98i{E zR8f)nB+rfdIy1U+99k$>u(^J`pOoHXm14$4T2?$OHV+MJ2|)E0nsAVdKb zztTVDMdxh(yI8q=qsOm1F5iay2Fj9u`OWXjBB8?Q)_FVBBHfRASLtR}?$BvsCGTI{ z*|>aO{brmhK@g+1+O2-wU;;GvhKlHs5Fsuu!@h_~`N4lJTU+B)-Fk}V^CABB zvTq&a4#t_|Im5>;^?I?`TNF<2xq^sJPW%YNZQl_iFtC_`$&eiLW+7Rgt9&Ox;P-d# z`W~AazL#E|p^P$B+^1aEEjVX6lkLu4j2MRo<+tQ&Ubd(nY>ij|+pzp_O_;5IpS1Kd z#T=)JlYuJS6|dpC24X=oWP%SSM}c5g32q1Jn=+?;N>kuoF9EuO77iFN=c)P~+b(;g&F2D7+@xh*(a*?AZ5bUH1VF za=#DBQ%OzjlY8$QAD2?OHUX7=hgK!>0uxaSqcq!pk$?(A%S%?CezkU;r>hH28$kxm znUYoUg&Aj_flY#Q)4Sd?GsJmgvE<5@}Uz%VSEpmE^ZdUyZ4rn>i# zoWb$~2q!Q9&t6aX%X!L`zMg{}Dygsco*F{NRs5Y-CK5c-rq9%0WwX4u=$mQ+mmCKx zRz%iANT90wPn3aIQIQ!b(C{4Z@dQZRew}p7G*-a~eUqS22Dp=T|f|HbP!Pg`OCit7IC;aS^ z&%e?CYnNoGd#7#t>5DPf_*SJfiHk9JUfXCY|G7HRVYK2py!v+JJSho*fHew*z)z}o zu_bTbVS|hjibB6ARjfS#_0o5>Pj`Zm9E9@1a+1~6Jv~QaO_Y8<)=+{m{)=L) zCq@L@320xMH_~(89t_)^Cw&ol|5=?&{Mb^zlKb@}sJ153zZkKR<6fWBG6KU4JA|Q& zxa9daKLch9_+pkla6_8Cd}q)~gFVOk}pkqT}SMGlDeH!g!F&Mu?Fk5wj8l zCuayuAA;6Vt5KMk-$U|{qtui8ZNPnjE%#TUl}EGewtGj57q`)L6K!)^R@FkZ#yNbk z-5P<{*7^=HZy7hLS+4LhhEjy=C)+g|UfwO~_EsT@XXp3!#scFW#JqGT-&x8nx>~D3UOqtO_ax)ObL-VZzA_&`c+Q0F!TX2o$Df zV8zP0xCp^RX99_=cJqral5Zo3vM9LNhU_MuKDPoiTU$)}@{O$m3cvI6rwCT`X^;iU1cVUf6wHRlf<9gr=g~Jno(B4WT!SY`KCK zYUT}fjN)8wri@l&&X@@_W}8?*IW3O?$|>F^!cyI_Qc3OV$0C=4%y8y7P!`N8_Ah%Y- zX|$}c`+DCAbp;42QjD4>&+~XM^+8E-FO&mr^6sK_zuzV@i*dg0GcIcV5HKm(PVkGH zgW=aE1m||+Vz^SPKsU*#`pJUuZpbS9W;=lQ!NKFe_v-_<5b)X5DMV6T6oBWqZl;Bh z6MTouXrgtq`e@^`GE7N&-L`r10&NenvvF1i=APmut`FoEs8fBj>iUePw2yZ%(O}($ zXTV3Ts-Pj_STz4*Kl^(%Nk%?y=}4IsbG;!fLbmyJ`CUp#Oqa54lB1(rphl|P7?6LG zd<=A7lt15!CvR?l^Rhg?=5XJtSX*%XInfR@VS%2cUC_SCegc#VExDH!yi8xq#q43_ zC^TH$x`O#$5j%yY8{Lbsa^P~-Awm25?c#P8_gx9%UayG3Cxxy2h8Q58Q_y?@{pHK! zFViloB~9|N59Dcar@Bm208K^#Xn5d#h?w8}24|}oFmGmNWeorycK<@|?-h!sh7%up zrbTI#L!Qi4dmbkP=kn`3qma;UQ87*K#UU_}R0MC#j_z)eg#&G&m3{g33?Kl{QVaYa zJ}}r$T=(*KZ5OeRFw&BDlb@ZhwcCK!j3F@?7Yd0}82ma-?P%zByh68}u1tRCMe8e&@;!IH! zKpfXR@W-V*O}wQ@aJ5Fqrxj=Qu{2c}qZr=y&!yK3!}sDNBa6zP$`{q5MGXfFap9MK zzPGJ-n{u{NL4msKo>|Hr=sqFk2@rtC&9e_H!n`0#+*iBG6FzW%bo;4zJ?&2^>amV3iKb!@aThmDPTOtgWba{G_mnAp{!apzrrsshd1yiYxkpee}* zlO@5cQWhEcHq!t?5yaxA6eJefes}ZUzgH=*g3R6hUAeg0{?HfYIqR=jXsFxvbligwd zO9TvfQc^)hMw5DjG+_Ca5X=I~wo$zQl^bYg5Q8bdp5!_*>)V|dtAf22?!ql}_0fas z?hhJo%U+mV)Jj_rqgYs#rn*~{AB-^U?@=h-I#>IEgXZFNOD27NgUo&C>bOGB!v&0| zs-W+|%EI!?sbT+qNFTTiBG8F5;Hl{}L0{)hz)uw`1)PCrU>)1$H(@DTD2Eh~i-(st z0-F5sPe*bqHb#M^$#2#Rmr;S#&s^Z~*-a;yl>nay%&byYT);95ESFH@g&oYV)~B9y zE3e)JIr;dSnU<>BEEXMaZyP7QNJON8#6qooobRv}K0~Pi-ogQ&{Xhc7A(zNRyrWW@ z($RXKeWwkQhN z1U}sv1)d3t;NiyS1YkS6xVYHQ4j>Z^`oXpsXdK__+ynqlMTHXs)iw?A>yEQ+mVf^? zhA}3cAa?izY+~S9q8Y#lP@wkndL143?%fgF`;(FE`28`=w2B22#0Y4O=~~sbLY0Xq zu>ij@&uGXy`My+_eEdHA{BacPejaTN;>a$+jh_REThCKgi7LAtrvx| zfu+1QH}JJmJkb(pm}^&LHWl3zP@JX-S~bc|@m_ErW1;9OvFj}DyNE&TR#bvUM~eb5 zF_1|=0;AT=4+~`#6&aV$m#68K7$ZK;?`$jqQUI>`Rd-Q`IWV$UL#ltUmtwSPIJ^#S zWgvG#%$qH1SEYbxEv2V7dUHEX)P)adC~&chiGmIyUV#ZQ(R*4k)&|~4_|;$qzLc+B z>@#9n69O%#AmFzPJIuAvhlw^p7vbxlx1T>~$M-$*Y&aart$(vyZ!u?~Ozdqn zTa*Df`_c3xHWY2&0^neAaV>?53@%ncasu}gC^ws@*LilkR(_xcJNiR^J z-ZgI+UXy<2F*YEQhCZ7TCRiVQTtfoOANar{+FU!rm13*ee2W zQ#Lj>YvB|ukJPlZKExCc`k53#XPu|?X@H8}#h3Ao{vO3p5U;|EOBwKs?Obgc=G<(! zcTe1%w{WlWXN-!xTe2UZSfFWMJ=~AtapsBL*_ktEjRTm0@GkFTPfws)mpmQ04&pZb z3>U#hKw}W|NeIG-quM`uClMVH002u)KKzLF}++Kjqa&K z%a=Zs7pc|e&ovU#ejm0vM&O1>?I79KK_Hn|LaK(a>JiQ)_I~-Rj~yL!;F{FTbIPO( z`G*`52w1g}zG*nR_RQeX@cs19O+Au2{C`?3N(fUB#{xN2WBX_Cm5Q9iHWF~BrACct zn|a_H{_V^gehi>N@tK-}&&^;$Fsp{cMJrrt;b<7(R$|zS150`y(--*2k%*8Dlt?XX z#HShgOKeQMC68^9R6&E8I?Ya-QNBzP>d!=vB8dHa20G5kFZ$n-!}?Z7S{*8sE?Bo* zD|wV$#SlsxR;T#n2O8I0aXJ8(QA ztF5ur+R+P(r8%-zB8%h{SG&K}275Em2@Ui8p9|NeGpvku4u6Yp@Bk3bay)Tiz^dxp zcKC;DuxtPB*`UW~lI+Z?Z<7TEa|=U|s<@KvT~7JDJ9Lj4gu%sHF^1P^Z)VswTnt8u z|M;s@JYAdJ>ww@mRSZR2O-)8{QT>%^S@i7rq+EDfg~YT%u)YKc(Qjg6Vx?+r%@$G9 z!Dt*=xSL&|t!_kcisf_=?0d$=6^C))F=r_>wxZkZEEa_tGYfZa35Id7ZZ0N;Z=++N z8-_+iWL`LSoMei~5>FS${6wn`5kMqT?nC!?6 zbM!)A_d>95)&py0BuPXAgq~CL| zjQvB;RFAEDFjTDrz~LRx_Q)tJn}Ipuzs;YiI$iLsiKJBXMU+~A&GjRDOEM|@zAHV{ z&p^1hCVlHTp98G9UdoH7k5D{A0elGU0XOM(eaa3S8+3XtAtTLT2XZ7>+W`dmJWLDt zZ-7xCOkNw|6hLWxz}0gF62gDH?+8qx!Ctopks6|uQc>ed>dfl7;pBtV*#B0jCnV9) zG=}_t*5W(Wp^}Om*tj)Ma z14LAo^-`aZe`URg4W7savzgx4x|-@J&^1kb4oLTRND^UqMsqu&?`2N5MuYkTJ8cAf z6?68#RGn_0tLMrVYn?RXaJfXvIEMq2_(YWmCoQCTnY2MrIbpvX~W6AIQzn-fP+{Z78N^Ti3TBCQ!$Wn%D z@AwGpTW%f5qCfM1R}Xb$X6MV!v@8D2A3*m51Hpk1JMLKX#62}3p%dsBN{}l2dk-C# zdvuFm%ochn(mg^6CpMehrWgLi+bmf_yO)UeCgy?}$zz&z6Lzn(>~W+cmX%yiXWK za3pJn61>@!o^CsUp2-fE-gaBq+iG4D=Xf1w=a9I51#R_hst>1_ z7}y6~g?d&|O`$~cc<*b242aTzxdFm4R?@K3t3xRNyW@|V-329JqJ)z}K5tT;5>KBH zy*G__Uh~e9C)WAc=wFY1rgkkiFds$VwRZgTpKwEij-1G>33&L1{Y%(`e?X-2l zFj&6+BW1>kzGsrbMPH4rdGV#*6=rhpuq8jn^HiY zA8T8av1Jj_2J#g)5t04yp59Z;)^JKiWckp=UciEspLNx7VAOHoK7oQ<%v3Y`-G|No z5mi4gW8hrCc9yq>sVXKvBkNv$4MM=q$JGe)9W)N{Ft7@&Rv8!=S|%b+p76>wc)|y1 z$qYx4w+1ijuhL#9UfIRMrtHb9HT~RqoQ8d@_lpU2gDlV>;ESeOo2Uce zatF3b3v87s>*pmsU_%7W4#@{*dl2kL+rKA#vwrz_DM zI6LEqBR1S#nR|NfY{9N5)gdp=GX-=t?D%lT-S;a{yhPX@((JKijp7p%L;m3va^#_4 zn+xXeRdGD=9KCJ$4uX;;^>Bi?P$x9bI;=%H+qjY0Q$k&c}=RL~7QDV#QtK+ZFw zfdEZh#$0jyj}S)b{;%L$YB(rS*nJ%Yg2E0h4Q3Z(GyCn}>f?f|&*+u{(Ot0dMcj`e z{R@R8AKWPLZ-B=&l>CSUD1_B4w z6GOGs_x7`DNTE)Jj|BojL6+UCFbf6jYa7Azn(!=x&rI0;KdIE<=Mme_D@wm(_N>2g z7`rALGv8rhEpwqb*#9^O;cuKtrVE=fD?Qs{)4WUb9fQikxl-)X&)7HmXY$^~4k)?X zf^MXB5KIB?03Zs*%Vp^3*6NhBfnU8Dm_z?dWfCKhSY>kCtoKiR+YdroZ2Sv3Tv)>a zHrgMhweJ{{7tOIcZy(cB32FPKifiDHkXyZzrWO!b}lF= zFE4A5Njj8(cQ>UPa%;$#Bz1JYUoHe!TT()(Os)78+M7c=dMD>$E^8}gdtzs#k|wj7>UZppOkhRSv@PT_ zR4xr!O`JSld5Ypow{sKpDxw zc|(LtU>)LLX9~?HsP`GqQjHGY1thkb*bC&7^1I2i$;zH*%o-RYkdd3Rdd}SEvvx3N zcUb5%?(~d)yklwX22)K+v^;mK!xt5*#qX#?Lu#gR@*=C9yv$6G2SN2$3cCgp(P144 z7}dPL3nayatB=*rM3NCkgEUs>$i#~CrPR!gw-t_T9x8K~^6zFr{eTa76rX5s$AUuK zH_8%MM5iOTebiq_&YTr|!yCZ$XN26Lj*C3Q7LD~#`dvU^FJ2bNNE?6{3VM7k=&8Ug zrk)H~QEy$#suv)DT!lP%pvK7xYI)MLE&H2%!VI*XcG?I`}H3&Z!4xEpB1pK zpJU+5tEd(aG{@$q4@$6d+2}5x<8q}*4j|aMjz`47ZUo%PaZ2H6gF8ZB zHa6z^;Pz>+gtQeIBHGke?ef39kx<@vxS8qSo48Gfp`_b3kw3ONu31@PYC*TRvu& z*OJQ`@Nf^R>}2;^85xx0QM||VkZ`eX7#shc``R^RKy};RDQ(a;7>Py6(Sl3VUws!c z1#!CEchCR=V1|^LkV<-m-je5h+UUg#%c!{DYKb)^(3$nISQme>XwaYQ?(THkdyEp? zCAUzZ^Frc+Gl*rs2PlV{@G77A?hYbngmAa}&k}=RKDbQ7cj6nvFCm(IS*lBUToic4 zc{6V1aAaskP2@&jK=_yO^YyI}{Ozru*6Q~=JT?e?Or~1W-1&+Mb-I?pXXv49@Zs96 z2$=vmL-RlVuR{{yVnv*7=&3?ZVmg&P49c!_-7>Eq)g%@dC~;#FNE@Rl#a@&1i06rb zQUi(aju^5zVn06|=->DeHZ&s~>b*{jru>D#J*b1nUED&bK<0R}khf-R{61#$klGyb z^1u!TCL^fNyF_(9e@N~xN*nQ1^d#r`9q&+?vW;tC6>zb1`1YF-S<`fM|IPD}YCmgH z3-4pqNA?$YqP;z|<^me@9}cLdpPYbzK&f?XCIkepn$kB=WCSI7hPf!q%UMQw#p z5%-ELiz3Y_7#}tK{xK04Cn0QMFhYTm63htU27vbTfBO_&9Bis~+Qp$vrqblE2$}2R zRx+3D_tvvwydD(55f-(;U^W@0_5sCx#mxD!oP%0d0_sz1UGdqkkJi{KCoI!hB!^2R!IIf@onV9MIxfp7nubaLHebgd zOb;gCA)04acE%7LeaQ4N5YN(;ebl0gC;le-5&q+OO)I-RZxH}DF$DS)WiZVF;Q_%g zOC5>6o}MW{B-17RGivbOp6u|Jw*Q6Z+kZnd_o&jP7!A~|SoKIkEBa35gs9TKzYRz(W(wNg&$mH|#%;T*JZGL{)`h${{Yp@Nc zS)e8O^I%#4=s`g++=|br4HtUV&G7c2Q;COEC4Yby<+imjUeR#)^xuV%@drADZDMPz z`j+fyw?vljJLu?IQhcHa8ouqK;a#gn3m5;aB{Kz$*`%WlcjpTsqw&xFzg`42)?E5X z`cm0<*?jsmn4_48OF%fytfDrq zOvz-pT9W>FB(eXUTkD8qd+W%T3n+#V332D!T*n;M&s}dKDZgD_kSqSRL?YM&cR+-C z_Up?)N$@?^S~(DfWCsLk1NnEdc6WpJ%PTGNK?POeJ?q;>`#jvm@lRZFG6@wzebDu^^jc zN*Atg#I_F6*8qfp4h{R?nNdpVFh2_>7)C}%!~8o$oa_$%W(ev4t%x*WL7)Q7Q2%I+ z%o8l(r0K7obb|v*fRa=-&LgsX4$rrikk+&`z&1BYSYV;guPf_jo;{paAgb^siI<5U(<%T)%^giBICC-h^ z-SBZDQ`fk$|GM9Uqq6kMLC6*0o8#!aN8Ib(sHTlG6- zySs>ZSJ=0tja?&I?@mz=y&Y#896A@cGAZR8^nvb?dY_csd&J zUM3XlmTV5(VyliPb%Rnqo}%1SJ&fZjLcas-@2U8nwBIJ^icS8cDQb64gbi=bG@t;r zJ$G9abRN!R9F0JK08_!=*#VvJ++nUV@iK@t*al@IDz^s5NxA`FC!sFhCqvQhe8mS#}o z%6w!!^M46hLlWg)oz|@;TO?s5d0RF`#$1NatO4c;NU0DCbqSI#QunDcfG>%a~|SyGrg_ih8P;A7mf}@Kd!DagiuQkeB<>`HrqMev9l{S zF)^}A@$Jl|9m#U}ylj2?O^d!e{g`j*=SmoV0v`L2;mOEahfkMBEtwmrnIED#PnlW7&4L>2nKu6QOU?(>tW#C@c^^zduL8{k0adX9SUM+Ox$FdI1v&nI8{8(Osq#pL8V zZEgnLJqd4_H&gYC6WGlPaG0~8g?ljBlzJ^v4TAVNlfMNCWaXp1d(2rrH?iN1@-pjrhSKR70<>V+lC8SIkP2(=f8ydcECVos9?Rg^Z zjc`R;;oHF?5wV>^o#~Ns5ddnz8fEy{s^M=rlF=ToOciU!-I>ChI-((^^2PY z)VDHo8d`=l{FL&VnknnPBqZR1&C}j)q@6!LiwkzmS3{+JgDVejywzTx?0i)dJclKud+Ww(Clr+eXx`qaI0naH^$y5Jx%4C#2 z%`ZxA-|0t9BvTx ziLT?qI7x{=r9j<^0{5ZF8SSr2?-zQjH{{ND}>Vjc+mG-?w+`)*X4s{Dd&09Y&NVYiit?N0`&N#CC~d( zbl!|YjP8164(v>&6nCk4;eQ8`)D2JZFDfL0=1$C4(msiW8N^ zh`gkLCL!P&(0kT{bJ<`o8DOA&BIFW;<&V%oU~m#JCxXxD{c0CKy!xR3bI%Xb@N3Urp+0x~;kGj&zVYbZ zShUC^^53rP#X1VmWqL0Y*bY{%YIAvUzjDS^DuwJV%A0;k+<0sEVfZS$KTPg2hk327 zt?EyIDh(}-k;jNfa@PcZoc#>yTq;M;#`)F;CvMV;OKzH^<(dUx_4mhC@=uet%u zy;86?vF!m>;<7nUpZv;VRjIAp(+k74?)CUDjdI+I=pKmR2$}U{fT%(4DPeQCHM?aq)aYmnHW z^4y^QvsTA%kENp?l4Go`WB;CdsAL@h26ozf{Ic0u{S!C9kId)Jo%s7)VDC*DhBe=( zNfaPc2J~sfWfx(HaC_~_ziT*h{ZgpkN2DzJtb583yx_3zH$a;Vda3`N9mwBi$W1h& zuTN%an@Vgo+52g1*5yW&op)d^3ep`Tq{^&X%&b9z+s^Cx-IdRV9f$b69w^nRUyb)= zcgvr)R|L@zi=v>^a>(Akss0c>G?aRU3>>)GS|a0Lnhj<*_wvIU-npYP)gBq^$HMR-0fMdTXFI(^9FFDB zAG@*461%Mb*gGb+dLd7`Jy3DFJ)misSwo#!MdcvT$a5=zS)YP<8O4hcw+7t;JNy&l zfIhQfZ4hbkWUMYPaK-EQxnQgRy$u5fvGvVgJs<>m!k6)M(8;*1v7N zznF_b;XZ0o+zZDam-bo6ZvvZ;6f4!yz^c0LMTkLo-NilzZ9Uu@KN`~L_R}sJlxX9N zFl&Q7O$p(A{n!>D@bUSo`>MK`!vS+?q~K%=Os@IbCGdqDho66T7wvU`Iv}Z8(LHl6 zp><7lRW%B_u}Od*7CJ?T`i*qWiwq(F>STn{Ub;|$WQyzE zb9OM0?^DD+GBq|3`DnT7SBT?vQ=BOYPuG+Q z%&Y$)2xA&(TVSM2hw}f}`U;>d-zQo^Iz_q>kPxJmMo9tbkWN85 zB&55fOS-$eJ4L#?K|mU$`#zuX``?*+9T{U7zW04%_w4T3vk*C6a0-4i)M#Wvr!xOM z&xh3O1Bi!%Gf;Zu1+Qv>p!&tSYZjw98qFglXl3Jh8gJ{h#zUdbOe*jgvDI+>eKVlD zD4Idzk@EAI{jRg0p55zQdrcnGl0c~_h~}}W{L>hqi+}EPx_}kY37%sIFyW!sI(p7T2kFQ*qqqEtTS=`IBOpRy@Xl=q@U|z zC(@l{zPw!rwnGUdB>)*1oqM@NlPB=Py!9{HEn+e}`2whAEJRAIncF9|wB{)9 z`*$SFdO+{P`$x+6nm@e}c}k_C3~$c^`rlz39Thk5bfn&8ZA8iFSaX^r>py}4>hPV2 z`!%si!PL}nQT3z`7Q$I?fCpklg_UGfH&tVZv0}6R8)ha!YC>S*<2!yjgv$DutmuDP z!o4m|Fy%p!$MN$UwU78#94cnar!rtjPe(1+r>{gW%l6;DinQvZqtdd1_RJ6jy@+w^ z>lj=#D$1f|d5D$$zSLXGi8%`OvS1F`~8phb2nA+7KG_W;Ac z;unMRV66+0AN8IouYIZXrO)P-9`Kitfb$N&Va)pMdmDn82ZPGFt@f5H9o)QM=8JtY zc+zdOOg**Ml;BV;2>I(&%$hVWAwELH#(DFVOj1qFXj{vKK1U7tg)r48&~VoP3M-I4 zLJV$8e0VS%79si7Z=EnBzdt%*V|%nzuW?}4KQEdnx6OnK!2*M!cCJFY6k?Gpkm*r%WveNy( zW;sU{?%CJ^I01Hm-f9vH^LrB_Xd$7&6y5Qy-QE`jNM=VBhrDz`OhS;B?&|^nJ9+vG zXNO~~?>8*Ax95Xaqr7iGX@lVxhji4sN>TD&*_l!;t{yjEiLq~fr6A`EM-zO zcb4Un$8-s-b{-DOGnCUN+Nz{sH>-|9I$F&ZVzKEG@Uuaa(h~;8?wDk$&Ei@heP33= zsz|+IrPDvql@R0Cy&q~l@ic?I?02r`JBuMuSXI~}ZaP_#>g3U_C2g|NFwRb0g9LL3 z1Qs-ONC-3~P>Ap=nT@)=0)@j|Q93`GM{Kaz+cmdQTT>VOKSlz$w9bEH7Ri(h^9I&b zRJ8am9kmltN&IW9zLW&Al6ZDD$*1_#0O9wb;0)$d0)|pD)Q68B$AOx^3jZH*yB~2E zizRv`F7`M)f%G1ohksLl+cbZj%snPjT!jl{EYW4zwT9Ds{5$3!r~w_Xz64hc^HUx! zKft%$EID&FeD1V?XqKB^798RI(65*UJ_dMVtS9;eB>}hZzgI+8ccg1X12l##a~*aCov0W0DTi0wjm|x%Q34QzITr6Fxf_I#MWU zz{`|h$Z|dN;WUHih?@gsFF;Pv1EfKGzc&^ZFW}%d2hc!XIz21P2XNGLZ_i5IK{i+` zy{ycNnMn+h!Ut0jks?&(E*`vk377Ocw(q}qdXYGfIh`)ghM8Bim-i5}8I(CUc&vWa zwrvt$5)qKA=|Ac1LyqL&c=Y0ueJEeyD9J%D*6=`-^^3AP3N8wO{}WTswpdkTF*`fc z&@e;0^^%+q>LF90yMyo`tu6}k@*RC&pF^0-;MZ@3bD(q%=_rBM2`Be5u#Yzl&q`gs2pOVv7yvqB`yoxR0=OSts(g z69>x&Swm~#ZR5@6c>%sJuV1kpf~T`nAcpg0X!YLnK*=pOg*~$WLMtrM7KNUAjq^d; zYr)vt`nf;D^{801%iS)Ogv1w+WxmW|(nW(WuUpit;b0#uieyDCBG}GL$24IMU`0)= zMX|0Ts(QbH!qn8p(nH)Sh~?;aGBwUe`l4k` z+OyudIUZXt`pGxw%1+{o$zj`EYN=tQD{eom>}v~2GEVIK30JC;a2zk_5=lTCML1#I z^g$V^HsTi7)Xd9DSXzh{=C!Q>qte?R14KQDV`L=)bWKL+bNH`+QUcJjkc%ZQA_o~W z0Ga85u>0xvnexG03mDxWaHa{~L^%<)<2gymAyVY{W~IXT;WFt1Lek8Hg;yMhOh^AR zJfq=xB=;smSv`e0wu-+c<k16j*nFBX_Ua41kW&x|#c|OhP6?f7_H~#Xy__r6hm&>W#F~x4E-%?@I zpW^#X3X494#zpBZps$)!i>$7`WNHd;^To@!cA=Z!8O>9DvvC}kuE2Fm2XPQVXkL?f1jJAc`I8qxo55b77jA zHPG~V-Z3gpIecpW7%M?qNguUzT-3>wHe2nSs9L@iBc%~Wg8!#V%&RB-t^yyViGQjS zV@AYk=lnWY4!b*6eBU&x!QI?y;Kdg8$7P+$e**=Sb z2w3}!ck%`3fsZ!;=_H@dJe0EGLP|F_9PISg3)EYUkdS;S-)ZWc~q>52^L z;DITB)I@BvJGcvG(`lWvMJ^vCL6X~{&M*C7d*Bpz=)Ku2& zrI{*NF^UPd)0en{r~=+|s|TJ4q$~5w_b;oft8drvT37ZehYi=dzCa|!D~I_LwZVT* zOiUaVTe=28CVTt4v_Lq}kQitnfdd6JuV&El-e0csS!io(4}&JTfuEz9)||wU;PlY# zqLKaOOZ&#=lHG;R%dSA#3@5f<1PDz4vErP%y1E`42$dYf;_D|@fY%2ZH{W#pjmY1r z@Q+anpcH0LKVI00blX4K5!>j(Qj6UfJf^Z6v>tOsXLVilk%uZk(@S>CC9r#{aU~O1 zW;Ej~@0auV?k8o{w`XgB0NK|+l9+hpe(l2Aa{cxE_jdWgM{-Q&@e7;-D}LpOC4s~O zgK8qisiXwyx9o=KJAy{=A>eK3oiK#DwMT1O}f0x1=57anz zA(VN@2%h)auD{190cqT9vxnH$TVKGg_X-Y1g`_#5AUyXW0HJ`cKNo~P4P9^f$JKQm zRku(@j{R(6BFqryr?dQE;wE-9hA%lb5=phVMb&>uKDg*Af-^X2)mNLY_XIl`{o^9K z^aI60`^sBoW%=vm!?6kr635nioWN=SZrz^92+uLP*KCpERSAT$N@fHEd8^DOYHEEi z1p_Fjb>jSFQK}+84l4 zblEr6rsU&G27)2ssi_!kjI|3L00#_%2~J(0CU@M+Y4m3iM6153JP78$lsf)!vvtOZ zl(0q^?}nfXP2={l=jbIWEY_fojnF>&+pd9ve92Uw;4f+D!vs5hZcYe;hpm=wZd=3O z?*DcioH_O1Ft@nhPI)g^p)qJPQMI&xwvwOI|#h? zMMJVKK4cjjnzh_O&sc8J4d|E+p5EPv5F-^=y$56(SgRhA&sF~NPyk=80Xe*NP}MfL zI9S~82PquKot>S_x#^ZEQZUR4O!ER~?<=qbkhXU32ARJb9Fdv9Z!de}xM!G#Z7e;K zGy;KI;x^fS^2g6E=c|t~RV+w~eC6!yZGBxqWA8Wuqjj%^5 zDd96R&akqUZT!&b@Ba#X@nfSAFFHsA4tOK?w-C)w&{+TxU^{Z)rQ78t#h^*H#R5c6 zV5V|iuK*OEo8Gqs(u7cu<4iXX%b-%0CNq!ncL!qZoOkNZk1;OqmtUPaYp0I)T{19n zNLlM8NXHC|TO>Rw%~TjsY$hS)LV0;dV$bNm(aQ2XyD1)%p4k2_YYrGD@#q*qR&l_I z5IwLw_~w34zb|xhx@lMyh|D0JOYnAMOCQbO;AaXZ<-}3&Fv2WIH$XR>`k1f*9qAYb z)ei7)(8rjTpZ^801R&k$6Vl_+_Dd}%*3LX)pZ|MKzv@QQ`iqcahIx6%kscLxuEcLD zSe>%pe{p>mw^^;SyVm>>^?vnmT8MqD;WhMU{M)OQBU|Tz^oa`NC!=8I$xv#TvbtnN z4Q|kMQ0+L`J=z~|(FLk15OC2+8}lZ)QWBA+E3{9C9*bqU36^hQ1(P*S7TU=X7!G~ zS@1zbUv0>aB^4;ByH0A$2kb+@GYf-=33>urDmP!r7e~F}Mlfig@NcgF1B}Zcc~p z-WEP{l;b01z(tvuLf=Ote(H!8W+pbj4#XA8)e!ha2`6+UBp7}FKt{?%e`-VWv-^?> zZTbCz4uW+r4OT{PhE^a;iED9V$zsRYOB(vYPkx!d*u%CoIc;8gM878|0Y^xpi1g5O zxxHom?0i4UScE;v8qeuqK$pjs>3rzZ+c>n9O!vFzEsX~qW7qw~>AOFrr0~B`6$p5A z@4i8MB-=oc|I`OR9)vCsiT8T&b<@nrshdn#yHp{vp{&16=7c+&Nw@K*m z!z`fsK3zj@HOsOse;KaPnc4Cz4c`y?%cC!W5&+=_i1c1v4VQ2HNXw)5(@hiG8O>w_ z0n)^4C{riyz`?1Tc3(`&p)`1fht0ak4eM1sSRU$k33%e{?$89!PK)VTF@G7wsh%m?&hHDCs9V(dN*Wvoa99cn}5e;Mjut#3^lXBUqGn;NW+l! z>xPY9xh-ML);n*4_Z@vV1H;4hD4~tm44tNRiKJl)$NdkmA{3Wb{o+#-uUWq&M2u<~ z>vlZ>R;z4g50p;J{e<#d@x{TJtPkzmM#C0a1tj3y^707aFenWcNj@*U(buPhz$k^7 zWvfm!sN2gPfoG&)wM-8{wz?mZ=KygK=sdR2cH7R5#X@`p8F~lCFd1uWfv6XJk26GE z18B{|ub=-7QDDK$;H5?7?re;b567K{**@~1XB8!P>(`d))=^01w==Hj9fpXBJd?mD zw)4HGJiX_KcoNj>vI?jQOuq5cZCk1S#;HVZ>57sljV(uTCaF$=a{zUG?|*nxt|H6E zDqEzC+WFelh+I`6Wb&{lCfS!yKt_;`qU(FY*F)gnbb)*(l-W#Kx+`E;_d`JjR@H;X zV(81Ii3{`vD460dzN@f=Lh9DVw13V%S$MBXDO6DH4%7{&;LS5Z4aG zwE6)>&H<%xfX^>Eu@lj5^5lf46I=bMFAhY^&3E;vur|Z_mAZkTDj{-uPIm7H zRY61@A2@~nSf(u?{`FS11`z1MsF~4wJ9XxbV+Tmybyj$rzUN`oyOfQK$O^UIq+ zd!cjVlO5MR7-{z%aPA9h*^L`YO`p<92)1B3 zECIJGAtJPU^x`lXQ}2fYYA3VPa?u0u2S5bnrfiRe1FLGVZ+g7Q)nDDNdiL_zC_ZGx zi1Pn+U90-|@RpXvYfIAtx`PJ|(RMEq`|1@c>x)u773Zx{9#nh`T6GF!Lb>MmXm9`w z5V`;;MMdKU@I%E72_Nh@REkw#pcx7R(~XTdsy%dHa$@+(3Z-Y3NJwaPQqGT_U!p6c zvl!thG?ZXWrT=wEdar@9QiTkC{l!_aOI_{@gXDP&c=_0&9e48VZTOIKDyecRUj#b? zD_S^htu$LA+XC$maJwwJT^a$iH_+dl4E&5fq@PjzO>&~L(;wR_CPjo3w_ zsMtwrHL1rBf20NaC_euoB@jS@@Lq43-(J%djVg-0<(#`_j~FIaBe>h*Rbkf74tc(= zl#Cb}yj-U)cE23gRt#oF7*!nB-;>UP+PM-E_WW`&QC+oFn|!0ZFiC*XU-BbP_~wB# z5ag4a}@^qSV~)_;G@ZGNdawR9tHXX=3NSn0r-ynH9{nLGw20ss5 ztG1^id-^E8B$iCVq#OZ14XJeuH#H0djE-`eFP>J*p_kkyBZ) z0GSZ{<403gPP-$`%z@lZPI^cbeoRO?e+?UDj4poDwH(()OUu`=^L?_ns`WZfFJcD< zC|7pW$tRUbCb4Do%2kf$l*XFTU8ITlb8&Lbe>`FZg8GifWDe52(3JV=?Qwqo}eeLCr%QpNxf&cVy80ry3g8UC$HJGo%7c-zp=zbx*PY0uUD3Dy(U7W%9+$VEf-o`jrqhRpbXRHU`8tJ83+4g*IQR$-cpGxzt<_Bh0x2;{AFQ1DO%Jhx+32F_o|O&rqS*zt;sH$ z%R?A9G-2~m!iLw^2plhyni1s7eMA*BC=*7k<`f0Bct{%wL>-frCsGPBtEgOVv5M;O z&Wnix<_WPt#+C!0%HeEsg7=o$Q}V?(F#PH#2%Ds&RA4=)wB)yG&+iA-XninmqW&Er z>#i|d@7Py+>n*QpDylPB3^dye);qTAHFzm8ESWr`E%4cU1`Ic8SFBZ#8ul-fHm&$# z@1UA9u^5k56^;&EWgFS-k|*KR_5uioa%-!?^YUI7=C`}9A#u`3b+FVR7U#e#KCc;g zHzA7%Xd5B~T!V|gm_lOzdvCFEwEJc^riiMcnBL&jBgl&i{Zd)4rK?%o%+=K9KC&Y-Z;_pb}3I>y7y+itQ+CuRYERIU; zoUg$Br7JO$P{p&IzZ?kWeZhU-_GcXicW67IO2;J<=KY3^JD1h%6hDUG9Bn3hjFu0? zqOg^C5WM#kz{7@{IPh;IC~mYy37)3wz6!R-NeLI*Rb-&a`^ z78!(nvYomUIKA5o5|E)?{?1b%oy#-UNyA4sd(jtOMZF(e!LG^Fe}pjNHw zYUjz=O_&kCJc9Z%CJ*vrrnJGza-Q8=riau9*7bgSga8C4K2rFb+|hlWzkR#j*?fCH zSiDzcA+3t{Yo^{oloU^Tc$X#Iq%hPZka`sDh?~10(?!~}ROk+B=8nIp4k5{n^F}&A zl$?O$Ms%;El8)QGK?hI-$fWNo==>udVGa+}@(UVgsft=IJ6Xav$2e=+4OqKAIJ zwIt`7A3CPHTWCmL?VMfzPEPk1^-4cX;doIBVJlRTnDlTq$U>Wtzy59ZLO)1xq&$)7 zvKbS{4+)`tNn|1lUvp~9kV5&~=!qb4{C`#*l?X{{>(Er9&VAe7emVTuO>3#VSbL%ia!0qk+j5Z%M_e(UOpTHQvMl~Tgi%l z0|mYAC~l@Ay4CyadD?t)an;y^>XX^{tE6Qi@~{;Ut6fKxPSz+0VW@?D5=E{c~ zi(<-YhQ55_&5I$`^3 zDw(jPRAV? zwr%0C;;fXrk6CAnr|gz=<{thL%IdGJ0kAo3CjPV2?Xv=KjiykC*GEA(hiRuo+B7xI zaY?E7*y?Zi{dD6xG11Bl4AN`Rmv?RAzzy_%``-;r`|^xEMXNedprtzb*R@JtLy|{S zzagI0(2u?I@Lp!r3?!R6Ey3|)&Y}-BN@84t6wa>bineGkALdcB&P%pcv>IV+THx35 zJWSa9f92OeH>{^C-5m)5d<&a)3lfq^4>#1fjU*hHuhFt}2jBLrS)m>)AyWt={CD?HZ5lV|o`(j;Ps)+{zyi*`q|I^tEF42z(j@8I_~y*Wt|)Gc zMhU;+S??y)d(YC_M+uy$v6k8i&w?9c^>KNAeJkg`4_E@#_26D};JQC0Gt z8hL#w`AW~4(+V3W9&M;om!X8xIUobz`pEx$^Y?^x@f0V^R~RkL$pzmW`WkL^#K;Y* zRYq*RuVWrb9a9x_d-aq=t^)W2K%Vsc8)cSB?9ScV+!rT?oyV#Es?(kVpM~M%Lx8l- z2|^)~+vMO?xn(9LrqFFdik5{1)7{uwedK+#gi__#rErJB@>yKTEwYoF)P zOSlVuxRi8S(jn?=Q6cT~suBelMg{)*G>qG=g((HdI(j zQZYHDNE{q0)7_Dso$r?ipTQ-E6r_RMQQ4U00RU8Nss1mv=PL+%=HSvTO{&@A{Z`V1 zTq*8g^A}epr?pvJjE}UV5NfOx_0RHkyExtGX-CXSIjCbpl^Z-2N3w~1Y0Zx|oj-=# zl8DHi3uZr>_>ArB&favAc|j8_-=RC1GlBrNVBVXN2G7(9fASERIc&iI?`Qw*cFTy z!Y5pK|W$qi>}ZL>wc@OxWaX3 zn*jtacy`TrX;MHT7B)yx9Em~{^WU}R44mZXar}`?!h3%pO?ZKl_07Pk-pc8P!U-%0 zi<_jMlRB&0MdK7Yg*PXo!rGVmY+STWQHOI7yl{j$^{JPxZf`1morh^QB-Glg`)!1s zp3EWbZny8|u&36F;yAgn%z;lONjmcDSCD&aXa0V@bRjE5@fX@shW&I?(#>gA-PgTfxhF_vzfaW!TgUE zc`i?0lhn5sursI%y3C)Sm#)v(D7Z~CeYxGYoaqUpxuFqwQ&=i?<{&!`qr%2o#wn)t z8aOlZe+P|H5KsqNs08&HJg0Tn>u#Z>)u-WqfKmvlTW*{{YhzSST}*|AU`$w0=&=5` zJoYAO_1q;ne_0Z^@Ze6rEuHSeG!w}k?n(5X6o8}1+3kgg%MdgRv*#uN^&TI76hk0u zbJ+da-JIzFeZFko{5^ouEb{+cN!Q~d2CR+Ijc(rB2A{W*8Lo~X(i&5}(r}k}tm;NH z8Nx#9lp9C${vgA#^|gs!O<2&!A8fhI0rM976Fd@pRq6WW@R!T$Wxt?}W)d5Q4!qZ2 z_;S5^{mtjGK&MvX9e4}^c&PJoeIyMB=$cy>28#>b7H$e3&f9-@Vhyp{*wUwvR9!?# zzp1|gTm%2**>)6D2`+O%_=iPVQOTzq;R(^}>lLb*ahmvD>pu~*%>3m^CTEv0RwvaJ zVjWpQBv@JMUv=xrHT<{(9D&2u)}AvHPT_ml3;p3?F}X(NMkI=^w^aXDTr92PZB>tK zL1jV?r@OX7LWS9uwnQbn(C%yP277n;Og|2 zj12Mo)SSg;%hVClxJ?fIfmuex?Py9QSHn<)VP5Gf70MvUVT8XK;$RLp25-9DWccM0Qwc@Mw!bZJ>Iu33#J&?=N+WUt zy?iXs_)1;B&Orz!sq@=ud#X`&wY-PJ=Uo?j&+c*(fkzr4N*qyF18g&_vxh6D7BD8( zsy(-SV^jX(cKO;BJwE|UkeKn`>D7x1c*ef8EY~?UhyzN5?L#}8q_s*=nPffEHnVXF zc${@8Ck8?cW81sLocG9au6&mj=m&9Pw0wjH^8_C=Ca1i*y*@7Sg4PH>T^;KND^i|_ zjE8#!fCKaZ2j_E^DyntX9iz18;u?I?QT)Wk(J1Fk z1_26g>~jOMhYq+)ypqFnhsol>kjG;AU4t)ti^s&l#d{ETBSH590+^X8@m>dArH2!j^cedW;8%h_Q)Yx@&E96`iU zrhhBN(#F2H(#coNuxREdmv}yPQD@#}pf9Z(?&zyK#QQR+DGf#4R7r7Ow)BNt3W8cv zZySQgChG_qY^K6Eod|SLMj9bG&U3F!3U*vk zhZFNo=GWtpDvhR?r3s8uH;V;vd9nKF?(P@M*N{L=Bj4BKg#Mg+BmTO&TL8;r9X zvJNQj+M>W$?Fj;%eK!;pGbY!uvF*rFGLu}t6Aj^Ek=tdGk2&xZfZ1n%gwT4Bc0b=I0kb8_GqUlQ}n zfiL=&Rt|{?`TYO8tIl6Sr65gfR>hTOU8K zC9%R`A)TraVMAmm^o?oOsVdi>gTj%zKIP_(1#_A>rGpiItR>Xhq64Q+!Gy)?DQV zTVPyhHaE!0`yu_^s4h90gb@{0jll|Fg%{;}4-$YC9z^1de~)dbB^NMGV~BOmDeK=r zFEQ=3yh$r}b2_dlaL)07Pgd9mm}Do(?csU;jG? zl=8g!l};bgEG7$_SQp(GBvpm(V@MgOez^8WmWSDQz|#NE3WF#GszD2M0G( zQ7H=xja_(K0GO;|Fpjmjzq4DEqN-GV6^w?M)h?zuQpX4WO{&;&kwiKcasl8}#B*_* zgAcbQ>6>al&o;!bzHuGVqf4aqv+z?A$@R?e>V_xgyNf{8%st(-90#ByjaDVq1#5C$ zREA0p+bWn$2Ey2J-pT(RWR_qG2DtsRH^k+U2}5&yb`0kGIA_N#*QB**4dw%Ms4TV2 z1FBK1y=Zl11u>lt)i+nm>qplQJS!u-iwD#Zd{6v&NSahvP(4IAp67&Oi9Yqj{o-Nh z7kpOQBskZd3Z@68!^7lxLWO2zP_#&ZT}6=mi%6`6^G$+}r@#%c$Y0r)p$9~U#zkr9 zXPcLC@9{UIkGBU%o_yPv#dt6?W3-d~1s%*Yf2%e*aj;pZ`@Rlky*M4GAy)4X#v}OyS*D^sL zop1cMhhjJA-WFxxW?>Z0gNBXr2qj&IG?)wN8zYpZXis@TVWWV*$>Dp*pa0VSEPE3mKyx<5>3{kD zXo}{m;;U}w1$&L0*m@BPX@b~xU<@5draILjbGPd2Gi4@*aKeA)j9w4nSnXYrxMuJY z9945*5|2NmI|?w6giDg4gPSIc4bD1}%koqtyj zm8+Ba(ny=Hzecgb)&|bhemdiVc{i!Qgw@nG+_9MLCys$bIhs;wo)lBZa{&)OS3XO9 zyEjYkjtmoZquJ0VGEaJr9K@0A+;50N%X?OMcHc#AY~VL|J`P*uIeE9UZtaq)=5si} zR)&N9+w>E|g$_MXB7^rA44|0Clq%Cx#{sv!VlN80Z&#nsf`@z0P)j!;DpY)H;+Wu( zMh$px97`|F=em0YI^}Vd+akN=D$1OG;(Yh^%?W&djkKm;#eJm`ExLn^+h zp!J{Z2$nM*U_`dr1eSC262np{x}<#xBmGiprUmu93Bw8Ig*&>SKF_tVffP(3EN{GF z=dV$kj!WN_VV(lRq1}+BF)zY*oog4FVo$q*g1#Kih3GG^s0dLoxIL^|N-+6g>Z$zm z4yXw+b>I5d3eK4evTCw>%umKNz$aUe&DcsOhQIl)FVLhbB4sU>nnz+ORF(K_muJbd zYSi%Ohw^x`S1YJ1$MYlD1PLZ_z65uVt@nt#Tm@a>ZA5VU z!G+EQugzt$Zk&=_ugXQb;I$6ptuOR+0XW*iRTZu7MVWG8DD!lS@ zswm;<;?5C{B4cAnP00-F2YB0`|LV!3mf3cUkkDal@M}|JbKqv()V_Uzh^|CQYQ2cA zbZU6}apsuS9Qt@znjiEAQMyAsm$2n@W4M`T5Se$9v4`$?ByuQAVVWWl_5hoA;c*bT zBCTsi!{hmuwy2oC$;F}u4dk(w6J4+AAYD3Jg?!ToN zE%eE$Gek)lCq6mz2(61~VnK&Ej3Fu_9^ZRK2U5GZ0U_^eVC!I}sCBrM)pxm{A*vp~ja zNmP!)896Hi=3^Ui!pCZ|7E@jr>b`zZ$%anjco|Vt)MoQbe7{|>sS3N>5R_KG?3|%) zWP_W{Dfq9T0tFJ*cGPwZrsf(g=~aoU+iF+I%A44Jaq9Sn?LQcPd0R|$%Hd>Fr=Plq zb087(gNi~Cbd>|xbSu;!!@evu?ciSZlir#VL+@zaBG?V{3IPxQjrSpGmW)3BQD+;A zA!@p5$v2x{y$+j$LUNOT@PruiU(^yRJL>aoF4BompRoQ6xn~2xYVH~jYvN_MqjMn_ z?_n|UMds3{b^CxAdxG4V#Lb`l0I6KJr0&+IfXU9Qi;kaB#)nTf zr2St*cM(yzOPi+V`G#|eUkj7?d~dxM>2I)unt0gNel#RNR`gP~Ehc`Jdp-rMy!Fq& zh=a2uS|o$Fp($=lITq7+8!-~*cdy9SwL%~y+*y1-jmAIsei#zvF3Pw2JMWBy91O3Q z7Y-0U(X-GEfQzc6eRn~cQjqseV&+)c6u%mEN3y{};kITrugNAguX|YVx-Qe+Atl`y z;BHi?_U<4#h8V4jQNqN4!x>N09c)M`%z)a7J@Bw<+{o=h6JTdQ=$I<@qSZ1*h;_iX z`Wq|}zj?VIVl}B>m}e6FSryM%{OE{>$HCg1PDX>rzqcM5Y9|vZ#WY`f1h0nO-~8B!1#+Rm*W>!MxRb>6NOfC-lqb&1)K;s$DnjXs8Nu{Rm`r#*jZaEV-ddW&ZL+m}@O* zdT?_=!8_OBFsL4DGQknGVN{Lr^TEHEZ+*x7IDVzlr9xk-N#{3 zNNq_I4gL~0N$RXKY$eH_7`GHd-QopqtSKUh%|U+XZ)Gol^7ZfD4GSm6Gf&NG=4SFX zh~1GM_#mOBb5xk2wle@#;1Eynb`tyZqG69Zv&9>e)Vv^4_8Z~eqbC0OR<#1DZp*{; zU-t+Bmh*o4Yr?_1}B$7vYn0BJK#gEBUmDDPP}J+=tNsjqf~Q zXK=5Be7ofhA~`B7{VVV1k6IXPRGSDQk!Y7^IHEu0Cc0*b4tgVPy>>`C{Xz&eNapsv zW&S}sY!|1IqI)e8DtF0(9*fD#gSQTXBf`?!D-BJB4ZHGZ;>L415Xy5ywS3#&+Y zHtsQ0FfBX9&0czzsY{l0t;u;{v*{$n23xqWu6yVo{%RaG^IJ09mCfs~WPBJgkHelY z(Y4DdPLwmmXB~9ofjoipf{djR`4(O)`F;ipMklC;`e5GC2$1|{0}#Xe>B`E_cQj1V z=bcg8nK2aSvc&A!X#>nMgNmX4oKW1zwKPfMPgd+QDD+i($>%}8jc&>Gwak+FJqMy3 zx2anVxi`>5^(Y(cGI~O8e!)46wopXZ%+DWb*6o?$O&k~Z z_h?=Li^i|14_o@AZXd-`OIOEc;W*6^SQ0_R%Uu8{1b@a4L6qJUfAW-F6b0sil`ZTY*L-${h)QZ=j;=-1cMQnE$o27Y`yT~s!U!_Q zIa|mBhK={<5Pohs-pU4@Xy%5F`e%S_&#oGRi^-wqt0Fr1NoY^rlU?D7qUeIlJX6{Y z)HtDE$eSHL!1lvsA#a5%37j&2ZM|cJYJXlr?r^2Y%zR_py{Xfk6M($g53xk zG%UHed_T2iqqyWQ(Bg}dZ2}_et}@C$k18y{{w>nX2rYH08ohU3VSFst72v>k1LrLK z^F08V#BtHT{XX6^`>TpG5Vr??)2hB@1SH!NOm%wHi#tzZ-YmX|yP=M|6R~MdkAK)4 zW}V)8_sTo3{9{{ORJ+OHdano(Wki~FH?rDPkw}p6_bc~WP%<-$Km-Y26V8PebKCJ} zBX(@6RVr5=1v(F-~vhvVO0s)!!1tqwo>inrM!)^ob%b zwOX1Enhi9nvn7Q*_S%py#S|0KO;D96LfDTM^Za>y=2Bxk8byVELE&gFF}zKM!hP<* zI^ehAn!kcYM;P;h`q!F(kyfMM9@!n@aFZigIlmHb=uqqR0E-!@1W~gSLyhrV*t8k5 z88%k@n`={7^k70Qckix}McSJC!MeDa>O_#W?`gGq<8CPfE;2s*ANRGE2Uw2XN+Ct_ z`hK#zl9Vo0wJ9p{hj!XwLI$7~g*3~b^CN*mQFkU}& z_hjfDtETf+@gU*1(O=h-m|I-qNyDBj)3UeESmvgLRHv5WM2mSwQxo5w!kwF9YS`&+2~l7_Bp2ex_s6WVWJFJksZQt9 zO!!4|nL*MWwYOSP)FcsuK4VKW2=78p4%+XH-n0c)b~flK+YH6U>8Z2+ zaCfQ9OOxaQGKbsiRKR?kM*N8lIrS!AcB{UZ@9~R#dxPqTgI2C}0n*ppd%B=H_hpvf zFc_KDRgsYYB+%0x~;aPeJFB{(N3d&JM@Ay}w-bXiVfhO6hL^oj5>JL$N3#a#o>a?rEKn)~yVpzsDnB1Yfcuox_E- zFntGXOl{jUfuDNd0ghD-5jkvZ!`W-m3}_QvdJ@xJo&NBuCq?!3c_4SP=moL+Ykx{(0lFXsI0FtT8n{3Wk=%iW3CvEZ1_@Iz<%Kij?G_c2c0y;K9_yl&?6cMgYcBT@bv zkzz2E5l}&J*$Vofo9n+c*Se5F0Uj%e*J+)^N?EqLzJMschszP`fXfD=!gG8j9|{R) z0z9ez*iJu1!SPx7DIMiE)$nqu(XYZMoMqCzEQp>|xW4<&DwLl3T`o&grspZkVnc>GMhUY?s#nh^5Ud3rpP(huZ&gaC)+SffJzHyK{ zCHEI>RROZ{RIIR3F!d3)8Kl}KZ+jdLFPPS#u=ax7{SKL|^( zPT>mkffU8DlKe9_XOsU0f2fY1Za6L78%v9CT)KictdxY=2Axg`?~7s5jt5^MCAkAx zvI2>@j{+!up{ASffceU7&=9gl0gvuu^ml1#gM=Mz9a)BpUr>yVvD~QtUp4Fd`?EOiFG*F^7bH&H=eea&OmyAV-H+OA-5}x- z%sGfQ_Zg*0NPUhgTVyr%YT(=ERVBG8a-Pt{2iHdd7FRxt)7t*Zs=m(WW}n=SQG$hu zQOV=mTi-FB20pOL1Ca6&8qKTa-71(HVX7B9T zY42=~N0L=>gT3$hY*+5?9K%^2kA<;3|36&41yGf5)b_g(0VNHjO9g3=ZV-`_ZlpV< zyOmNvKth^LcXuNo-O{n??rzxU-u~bBJKs6X;EXyj`+4qK_gd@vU71PV6}-I$U#=GT zPIr{{U`*CK4dZ*86=dgx`;~#Hp9F6s);TOCfR+X%1wf*7V(MEO)kwZ2#B910sVS{F z11)+W&~ifB{@|Wtzs{iy01$q?L}x?jB}SF;xgNrGcpMo0o7tv*a?AYJ2a~F=WXv#JH?WssMDyQO}*D2^Cj;XnOVKBT?sr#1*7n2#PokZ zxp*x81(8%0NcvV>rd@oqhg@zG^J)*4ay+GASqsx6_vlH<~aGK3;je;QHF#fz6%r zD_OW9GlLaQ_9srO-#o^w!Q(SJger%_ep9S$3RbL?AkFE|{rA6^g(M4HK)}WO)DpU! zpG`Jtd_{@~R#-r=0T5RL)>G4#2#G2vz z>R5A7z9D=5ih024@_lY~vs`k}ii+pMp4>Ver{7a-tp*%H%Y>-@!Mx`eCRNxSiF%x` zhoHvTeM)QM8DA9b4w~4TkKTaNBjHWwj=B|RrC5b|W5?PDJMOmzEC#bR8C)EO=2#_w znF2!RM?megeBj%jqIEut&htUFC)lE3+Fp% zjOy?2w+=r6mZRWxT$G!SzS(u=LkOsa?zZHB^?T>L1prz3mhpihPeE}u@|j?jS93#l z{Q4!*gG8jaFK4;9)1vj?-RU^c%r~DcO8M1sCG@TQSTA&YK^}hxtv4)-7gp|kC;@%c z%3=aBc5w2SuNy_ZuMhv%blEf6@@3@jSNNPm&t>N%3ws6cfCQyoD6#k`-+|td>o5_q z!Y(In!3&Njzn709u6~r@oH2r_@x*|DJKvAi_Z`J=u9m9-roOIF!l>h_;vFs9Rulx2 z!R=(-?x+L5!f8hpA=8_8UBLplNdZkwH=^^h2r|r?2-U zbb0CuC+B2yJ$$214tR~_vK7mFnm7Y0q_h5A@$E-%-%axVLYkV35S+(MV+{{z)BuYE zdm&9X(F~uNHp_UAD4m$==hdC2JT8g=o`{VZM6}h6C2^6fL_1{OKy`%{oyM7Xo^9y2 z!{6_<*)r|RN8P3Thz;FwCl@Chphgj*JrCGE(Ddm6n=l5Q#yH^LntvBmfB;_1IgZNA zx$hM5UoQIX&DsFu(5y{MGGIYeSd7y>Kyo-w0b<9#{zR^DX78gADWD+Qbuuil8%x|i%!~r#A$#*#n*H_1(o3YGdUns z?w}J=M1P4oF-Dzzyj2EWH>HMKprkeljekYKs84bL00qQrLNF8%C)l5- z3~5`t1tH}Hoh0@(K9^%p1ZwlCLmp^=9L8nzx`FhRj6!{ESv=X_rm&a9cro@9*>r54 z7Q}&S)$>-|^KQ0!N?APpvInP46682Bn*C(*gwF?_(3ujg4x#!Kvr`?s78%X;T>Nrg zLVwUgzR(r`Ql2+Tgw`=lh>mx32Tyb&hRWv^&t*W0HbC)rvg#=N?4o5i33V zJyaj-NjCh-acZ4=ssvKH^3^pw)@G1(K02hwcX0^}W^c<_<7LxL$UeD-gDwei(`u}+eCV=9NyN#Xk?=#?(v|ZL$;VJ zwn7{_aHdq!_1U4OWR`Hcr=6H=y*(I*)`qFe*y-oi;Ho;j%rpCt&Hv&CUa$$x5R9Xv zr(XfgQb`R?Fmfs0*XEsJq!_<}Rq2RRGN7(Fd4yK4Y1UYO#>oaSj9}p34(#}R5#R`d z`B#kJqf7GyJoC*BY_xIrDEdN@ zY*3)!X+r*-)nJP0*6Hi6+k$J%Rz)dLzx2YRW*Fpr zKT)>T*8EBn&)6Io010bSA&IVPtJdMO^XVhylGw%Np{MoRQd>AcL()D6D>!&l+~J`G z-wYJb3!h}8<^SIKVMt^P13BRKb`|QPq4QV14S9!bG5S`?z3JDCI-`&J1sbasp|mWt zf0oMTV(*PRH=Io~Mc-@f8tngQT{S-UCl@}0qlr+HN+0$RYNj`R$sH+y0*#e>R64LL zx`p^N(&8jfb(j(GG;k2vxPs#q z;Rrchs*eo(=Oic}M6gnpu(d4U3^!}fZU_vFO(HNFX`^Gk$)ADgDDa!A#u*~8@T$YN zTaD~;S>;P0U&*wVA9<3I$s{JJ9sO4#^zVev$L;hnaUAPFzK{o7WPH=!d_3NK8iIT% z=CWczNuk(vk3FmeZp)#mWzAta^~`4<<-%>cr%E2=F*S=fIhEWjp>-KJqK&!HR`9R{t#?4jzR*#qPm zF|o*{7gdv`crW6K>CuDxzV9Sh%8Q4$34SK}af8!mNz!ea&P^bGa&t2Bj%)!rq{E98i%=@sqL8W}8B!$W(5AC8zhZa$ zX@;V-C_j>xWRey?1!4S*c3H739^0wF1GotzR+b6XXzN?d$ksCTgkDXNTsyO+@01l` zA~-*godApA=b*$Jg+f(>kFjZcbJNC6s)|x4Av0kDv`an2%l{!-*|Cfc0#^u zHTV5%TbX&6{eZt^9MLvGepR&v7>g$I`+ z$D-ycmdFNaxq~Flqul41sN@0ua@ehERpsxlc~Zo**U|@5=W*FU`i#jv)c0B;Cx58r z7R%f(TTjUy{S}KmImw0i?}N0Ob}h(eZsn!kj+3e=#FM2AQvZ~jnr)1gY?X}WjEUUo z**o{hrK9ss^m%u%WkxB)~E= z8I5cDX6f<*XVoW`oCOT@dRF+3&(>JN{{KmdKkb7+;#?@$yT2W#4-9ob%>=>5+0V3Q zyeZHlmK-7yO$Kda6W5KDyG0c0gA7@ljIK!AZJfu=9&Vq%gYBQ=*mS584K>etdODF- z5-9KTegk6mM{DXpULmgVG-fxkTy)XgAz~`@XLolR^6=UE zfd+GIqqJfA&yR1kz780yFq6Xr;5F;;qah4<4-K3Ka$mMbers`8!p#V0XA(G9WC#w3 z416{}cVg082fe6)ucP3#EeK4$2p>C5Nd#Xvy5~Ik~VqYz3tRMk7oTWyvgf80l#GBHvNRQ zmv-A3Zp+<&6yH^mZGgpij(6;J1mh)pB@ncCe);?iQ zE@2t1ddeTaJXC^7$;qJsK+|L2^z>H@cpt+B(q?|S>e3TfbKl>dCa12C`RX>?L_2Vrb zH8DRpi6*xFxiGVnM*|a42t*5SpSTeY>|KZd9&}NBPm}JX&FKJPyqWN2#16Ffp3*{d zkrn;&?*65A^=AJ4xSIK=BPMXE?ztYpv6D{v_ESQy^4!nj%9hl9hrHV2{*-#TqeFm< zzRVP7M27sQb88OJlY-DWXonNzQFX~?FA0ogjOQF?1d@2Jr6Gb`OAeklMK&jP16{PU z)`c_O&&dT^YfQXDEOQ*$yPB=qjI~bxiq17%Tc_|G#7BG4Y0_P71i-Cuyw_Lzmnz`b z3_i2hT-GV(VWArJ9<-ugUhVU-Ik}z=RPc%oMsvkDTx>Vgf!{L6doXdn29E0WV7N{z z2nkF^;zV=N;)JUU^Fo=M&Z(To4pM)sz4Ize$>9?E(YL)E-Ka!>Pa^K-!1K|WVJTF> z^prM^?lZ$$s}s)gp2otPRm`&Fy{3JKw>mTD&mj0@4oj&daQVsryXPX$acW_IhCfUw z7wUV!&d>GGDaH>j)v6d|f8>{JC zDlhaC`*nRo0(g?^!=HNj#0M>KL5N{zDAMmWdlu6{62$EiVD(~H*`NRcaXw=S934#t z62_Z7?{BXPf$uv4l1E=6{kgo_s3-FLGib4tDE&WqxXE5bdoMz5@_+Jhdrd@@eD0Pd zIiLk8nJCo7Uk;YsQ1Z9wJo?aFc}hmVKFBU!2LL>kod{F}L|^9#E6oN8;R~J>gazk( z!0^14Q@?P8Vq4}hqjI{j)eU%^t@-Laru{s7!tWE^>P{|iCR&upOS)_h2vnU~ zS@1KJU&~<1{INVK&+(SKVTHoPML|a${e->e`VmZLW2ZqEB+s~gcCWMty`=7G`45+n ztqNv4{fGU@4xGe|(4yL+Tn~5AdXD2%#1r#5k4lz(%B+T1f9-2XrErBKhHpsj+B3Xs zYj@)Sp49eEA6ig=Ee0Yvu8YQ=%)`uRla+Q+!@1pn%_>{t+f zVN7aWy3mrkbs}#7l0B(_-i2Z?)uwt)&4bXQ8`zl)ly#Byf#ZvRss?oB1_6N-g8EsQ z&oT&f0LyD7(7%Rj${)21CXW$3-!n@{?Z#j$PhJ(yw_-y| zjxkvg2+$b<0uc?T<&9y?73{!ZKmxQIGAX>I$$R^(kV$tS2Hq4Y?eeCg!qnl@ zy?GMHHA7($Tkrn2UC1H^{OW!9MsXNf?bTUfp+-`K$AiZLt-lf&bAx;RhabMlkz2oU zXoi;WrCbZ(USeRoHiDyOkdNx%a6{CmGaYlDeCoborR7g00e)zc-teFu4G=n3`{q;2 zJa%A@p8Q3m_W}iwrr;n_0rSjC&N=puI`EDCB&`clX@fUcL9gmj|GKGb(J99+9-oUb z%&Ggro}JIaFV8f^EfC?Rh@gOhi#8Fke*^Jj8@REMKl}br=&}_D*r0^1o+4n~)27fh z;7?!z98vnALQr6pb_6Jt$!%@#_Zv?JqgTSog;EeH90XR$d^8)8<{BLy{>?TNx0VIq zm4}C*xx@bn%+=A&4q6ZcI)3IZI#rqY1(`WKEs8wzt-g_-I{6ELZ6uFk)x;_=7ARhh z{{||ID}B_ycp>N1u#{QS9{mx}Dd>81LF40xKkr(Jxpfp%V+c##jD`hKDuwdaI3pCe)Ev}@J~AR z;S~R7(ZPnxPzx~>Ho3uyN=A5oBMfD+3C$fk`1&l6aZD1X*U(~qHnM0}0%>6UxkYtCD zL4nK_%wJ{ZhDcM24Qsx$pNpq4f3ui(WinjMG$uvNdWc*Mid~7b^a=%^bfpaC!0*j< z7`+Rf-4DvVu0O#gM z$Yw%PtoZ`8f8fcO0hMK@?#=Dm9Sre(Cm1WE(v4NyEy)%%yOv$ouOmmO)0yq@xHiaq zTxpBQ^=6WENof}{`u@Nh$3ELhng0tNbV z>EXUiX72syin7r0qfbs8(=PMh;e~7?iHReICW{t*gVs3N>EvGifE;%2HD1Z?iGim# zFxVz&!EW)12)H|ifMz)2u=MwhP0ET~2sx|f!cC9_kk}d2_MNPiv6WZ!9WP5v3c6?9 zW~i6Bpagtrv!?%q9U$Ul`>EADN!6AJGGy)ji*KPN%Et!I6)V#iI^tas^Rb>jnNS1m z!eb6?jr?$`v#ZwrZ%luGf}~et%34;RoACyHW9eDI`Ccb#r~S?MnhW>F<8}Mq#6Gu; z4ww4E26#L0a%jJCTTY_JzIJc*B$-evHU)F%Aqm=KzJ4hw)Qw+QJc7%mMUoP4>zvm` z%=uVI-HZVRV(9^>P|Jcq3b1bfUjRE4 z^yqB}4Q&W5)UdRH(sDFrJH_85JzX!1n^y|GuM|`d3T4b4y>q`>Yu@ZS*CFAfyu=SE zVw+pbsVn;wR9oKuwtsy|K!tyNFVLott3B-l24zB;JD-(XolRE>Ge>*Q;0qaQuXcW& z-0s^_k9uW(ETS`t^-;b!udpZIn>_LwFI@fvhm?gAiSC!&(y4i}xX8P7mcQ zCzpAP4K3_3fqcT|Uere!39ke#q|ZE1&YCM2CVPK?eLH!84AKd`Xk?og-ijPl{se;{ zy8j{c`S!WZu6st)P&8oAmgVjAXKdgl_koPfx{IZE*O2VJMl7LNfoA^WhHD|VPoFD) zrLx79TNzAQWoqLWAnSqmJlC+UvkyP?zIUA*B5cOmZRhaYcOgU5og z_4yprz@BABl~p0^+rm2a=A-k{n$3_5PCz{Smhruqc#OS^Hsr&=aGs!)M&@f@x+3%R zMiyCimbX7jyG-49IJ@L@5<*z?aYX-Rk|iBx1D*5T+kK5Nt!~bKyWgmlZjl$gkUK*Y z>?~f35Nyua?itN9P%ZQ2Dp2ytA$hb4Rv}8r?|oUUl)p+JdvOZLe__Mwf+X5^N=5fD`~p2=hAGW@Wq zrlo7+jdz%*uAoh9qm0wXNpG1^)@-}Gu^l&~_Sd`Lb{G3cOe5DNGC)udkAa8Rl5OP4 z35*|D)*e+6AMxGx_a{mIq7DH$+94Y^vY_q7T^5s;-y}BJ3(ymDLdqDbcMq zCDgR&U$R~sSYxM7bhCx|K~~)zctB`LJ8N0r*-`^flvaB~y`D0);xt#WPIg~+2W_-J zqNf#T`gaz5<4PqYf@oUSR6}h{CZtV_hqljM(N>PX$5y;9_cN{G?=!u>A{?aW4F@4f z){m^l+C5RaMEa}2!UJ0nxuQAU>Utg0zbbo_gP$~5yU4wppA0U__bRu{76u%i9@{lI z+^;h>lM(GScK)M|e)*m1mfd4dIji8B3-8@1#k6*OHY-Q$7&#-GY^Ee1(39()vQ%{W zUwbu+m=FvrdMpO#G!m9xZcPN|;IZZA&!Ddndk8ffd|4-w9}B~!h`@=xCj{-H5DWOC zB#A1(On@rl%W5ya!S7wm3A|Ze$v-9J&koyxDgZ6@==Q{x!~xTcXp=YCqILNwm5T#3 zumAbhA)iG0GJC_{jilmKrf}))6t+$Gwr(v1-gX8nF0d|IMi08trKJFKf_5Ojb7g~n zkXh)GYRKo9Uo|8q%0X^DENXx8bVK%_{Eo}S(@%0qlE3N7oD6nI^H@L_?8mfQN|&co za^zziTD(h&I@3q!Lso8qfv(8HrV-|ZqqR%5#ENubOv$#ihPQ^Uo}(+bYZ)!WkybBv z(b-;df3X?3oe=iH!H%7DZt4`OfwLJ;9|*DEI1ThBj8nyZs#8BG`N`G zqh_boX@}-@d(eN4ZL{m`$MM3_6`U08)CrHdA4p3=+XkX+Ok_xa|IL zFFmXfo(0@q0E&`67WoH89+ouY1wB&x$T<=izz<;!i-rAec%iolF8~wj=UCzg{=_Uv zhJNOKUNF*<=4n^BDU|=)oY=pGqzXkYAB!Y8A$*&N-<57`|EsIW&<*lK^hFHjOs1Pd zT6Mb;LT0cL+>e{88mz?{QcWp!OWwSI&G&>dNL!ECK-Kyvr+Qh{Vm@ z#v^SfN!JsalFx5iHdn)WA?MM*xRunP zw{1E=Cf74Th)mq)fs=(>#o@a-w$YQ`W*4$l+sFHrnuUR&hf@#y=$cs!ReY^;{6iC|L&tAG&zp7|#)niI_#Pk{hz=b9pek#> z!{^CDGwH?SlSjH<9KgJ1elB!=PP)*QudHYmVy07t9#XKvq5}EX685NC0R?c}fD>QR z-QY^Aj(T@o4wUTB@;{xYp9n$)X$JUjP6OX3g0oXRcsV+5c&dxgglmd2a+E9&qe*Qe zW7l5^t)^3KzF8z%*4V`k#0v^);AGfKKQ`=+n>wmu?jup%7_irY3y+@7*&XS~JB_w7CGYcn_Z)9LzB_ z=dQt2H$o1udj@oW{_}J6oYM&(hfyt)OXye?H>qiur(g1JEhh2! zuN+uz`14F&J0$jDrg@i67N$QFN!Rarwme&1<3qYooebDE^+gG+k&@4-s6d2fCma8j zk*&foSzrEmyd;f_<4XtVV|DlxG>a4%#(`_Kz*%M*PAKJs9)G*R(AHx1^c>_bWj78z zlK*|~Z%Byd>SbTh^ixH3#@@+pzw;uj4+nG5$})%|Z%di86MJejFa4#|d%jbq<$TeX zFiV>fTZ_b2CawWd?}U)sh(xd|m~P&>d2nEU6YG!2W3;lEnml6J5ToJmm{XnXUvl(m z!VDEDMDK&bW>#psl2DkHA1xVDe8!w23G&o50ERq_|)! z09~>L`p1b;h~&cm)tDD)rb}i!UF)bSI0;q!@XP0rUUOLz(Avy$GfCrQ`hW zcj8o$+}eoVcZ%3@h#vbb9r`xWRxY-N*!ffm;a&srU-7pba3h>=kCFD z2TH^a6Pp!^6GLtmK-tMhA!3tk+y8%_EXbhThaMOQ5DrP$z*MG)T56e zag*S<2DW5yl&gT|vT(DtSLc{IxsY~^i0`)iKPvrVyqtZoZQ%6#i>^=5_Pe1hc5$K- zonwpM`oWTx!O#i(NPU9!6VH|uX?OOSk=Y&tB<&5fA#>@O(K4APjOJrw`ruLFv<6@5 za5xrZz;Z(VWHY1D_G{Zo#zp*1O_TR0UBf%7gSG7Y=GdA$n#qNi;Y|Xu*Vj7UZ_$wh z_&w$;X#05KTEnwD2l6Bn7sq#}=UVQI{b+G90tH^$pp+8eqxt9*egxCC8OlO#y2p-b zjCGlUH#MFN4GT2Il85f6eB5ZX#U8!^zo5sXE^n_6aM(6}@P^OIHZdA|>H07o`~W{f z5PWf5J$YCncs-7^&^N2&Gkj-gpGikdcu27`EaL33k>B#3GM96QG+sHypC_V_@57az}W`S)+H`IKP-DI-Re zT`Ml$SeTu4vtfJ6d~`3WI>kWW*56vo&j%ZvzIU3r8fhykCV^3b>80?HS0-qu1S|r9 z0Iaa=9Cf~R@>s{7Wgdr|qJAGoJKXEZYSSEh>cJY?@KxOytnTLKdM8}Hd}n3)YUArY z@YiiNs9=W3#9cP%;cdCE*DxBK0dRs%z0eUY(w=$)!h^)xV|xBkB{ustw<(k&iC2G1-k8m-r5B>Z_=)C^zwWQ za=i91q+9(h?E8awD%ytHu1$54yMvb*mDx(4V`PPEVuzR$Byvp4 zAXv{4ztMJJ(;TR-=)$e2VQbSx0mZKzL(7N-2NrX5!B+8-7%0N zSG0R`P>?&Ng2sd_zef{XRrRy>#j#22+R(E5K~UU6<2<}~-mpWo@v$b?f)1Px=ymTIeY zyEzBc1^~EWjqxxwU6B<>l&%MwlY%k{fj`aejK?fQ%d)+NetbyJFv#EI8HTNpasmC% z)~SAu-@A5o_GUV}IGZeSxH#6@-k(E0Q@=RZ z+AjWEx)(brQWUrSKO0Hp#WD8huKm3;m)v>}rm7a``SCU1szaX#&+w5eQQ*j(Wh#zP z>_q@#4GtcxB@e-1hW3?FX`bS_Ki@{nVKPg~nA&J#K(Ar;M23FQW%@X_$*KYM*nJMu zvc7Y^9*l&$B=oOLw@kCc1Ks1AD~EyQVc`XEEDfdpceI275|8=aofih9JovRv<+~<> zA6P=aMyf3KycKGUUmWs#Cz$zzl{f@P@A(m#Zo7+7+{GAqKB>1B&pS2F|jPx|z@IwZ!uItmDa#l#^?VcM0ncPitPrik-t{n=bS|Nkx($GTE z1F$x8uTG)}TG_3D5y5%7C`bW=2&-cWQV`Q=Y*ThIyRbi6wSITb1sfcu4*Yd16#=}{(^%0jjy7GPPK1IE zcn!P009zFGH8l~yi^5^Nj%iAbv>djqj@(R6nEW>LNC$JUnw33GwBoQ7)G03$;26qk z5Y)n9CUx*AiJuVAqUTN@;7Xg!06AyWl!x)8MRDe+j-uqJGpFO&W*6#qrqHJ?c)8>Zi~1K!hx2K){S<2gQ?#Btf8 zR;9+u=yeqnkui0e51y0z(T18Lr1}cw`T0fTN__AfyfmmMYju*{xv>v;&a%df!B#Nw`m)IX~ID+S#DU4QZc)a{KoUrx=-%kpNqN<(fzn8OGO;6XF z4oYjZT>Lsvf*oW?m(#DLDR8#6fE%L zjT)y4N2i1tGC8Lk3Ej(@o zzZMDYpkE1Lsd<}@6<=Hulj#=&c2djcNUU8bsNjAqIhmY<7nv?t*|;y}mWsbb86X9f z7)ro)E%o>Daqj2EXh74Y*E%8OYY0C1Q>1{O(ux4oA3;Ok^{uPta$Ajef)Mowy=v)& zCKvLeX6-PE?Bw_5Z8HXox%%NvC@G$J{X<+kC4~rExv%q7Z5{z5CO<+C93L;lQd%Y}gQZOFloIdh7x6 z;2Pd+X1zyzM?9a^8T|Mh`_pNrJUnjCivxlOV9I<|_p@&3JIqWdYNdu^_3IPE`OFD) z{XTqZ0;Dd4()^Q!w#OUsBHOoD$TQrxW}~11a?NG#u`dTM5ode-_WcnK(2G(j=!i`) zE_5jCHB5G4P#PqzU9I1CEdS?dY2maos!by_nf;)4IkE*gNxNI3s*D^1?7 z-~R9Z673v50g{B39;PKlIgzdO?_Y2dlT<%iawv%p<2uehs^P>L0*3-RclS}lqgqrU zt8E+=mPOctOTRKc5r*k)w^rUxY6k+UxYf~2T53XV5WOhuQr}HFNN%RC7r4#A#ki( z^_afvY=i7U{e+cwXqMd4<&AT5&1p}^H!Ov!8R;=Dl6XKNkh>EvZR)cLdCo-mX=3w- z&R83wQV(3j<>$%!IFO+U>d0<-iU<`|erV?;hgVX%UM7<^)>cO&z<8^3A_RC7k{`8Z zOK4pu(;PoTxw>w6yo+#l^8@5}hU>ygdt|39RXL5XlU6f%dd~TPU4sZYXGh|u;+wx{ z!72V~K4WDRmiZc457&<{+mq6Ysa2_6S^3^mC-xP(md^}0%}duBZ~UC*_IxV_7x*)##cV)H($T@*bi)K;B<6ay zP!}>dq`{H(-8JiKP6>A=f%l5}lASUV;IGO8zam7JSD!g?MfFdiP;UQ&(0^CIC(C+3 zT8Iidi)IXI#x+z7lK`Z&qfb{TtE)H9?7bza;G{W?_D*Ttod*|3=k@fT1{|yW_iwXI zqXsacbIz=HvaDX%fQ@dC%mhS#`Mja-2hlcP_dcfOJxj>{1pj>3Scm>>0NEiWTdD$I zS(74_21xL&B=&nCx||Prxv+6u7Pz%A`2;xn_O&>&L|^@J*{4)c4fE16^+lE4Y8QETE`6x~daN+!^K}Zk4Gqd=D;r81HFxaL*#m@+A#l@8qVE)u#=Y1@) zxA=zbIb~E;86ZV5d5}e)A0=Gy zs{k{(rL415)r&ZpKmTLwBaZht`%fxJkfvyA|Gf!JZfvRwuBlNa{p-evPWq$u!t$9l z=OdShJ(r))Z6lW)Ftw*}ZT8=2S=_mwJ_h8fyji&UQ)SP&8bdW})Q>r_wd3O_)p*3U zBE#y!EkP$5g{7Ot_j|7clgL8SRMZTQfL1oI-IU(@&hitd_4|TigE7PKG;qj%$WXq& zzX!b)N3$NK;lxnZ41@GWey6|>zY`Ke7@2W^PY?*eM1>VpDq^ExgK--vqF6x<{aTY^ zZPfn;<)zra2PwZR_f>JnykxS!*4MrmCok>eJXfMZvA}2aM zYJ0gHXqVkw>+Bfr38; z$=2r|rj$;P_9SB6TOF=_r8mOPX?WdUSo}05O(c2+?Hxzwa$V*MURNnHC(WRbp{_V+p)-O91y zEO%P|L6U?G^aTCLhO_o4qkNO*UiuI`ra_JnzieEuiK#>5$%~6)q_1Ff%vT z&dX7jHwO+yAh|p$cj4aTG$bT9(sKX(;B*w{{MY}^MnW(fDfntN#)|o}5RnqHCu5N# z>GntW{h_h*5%tGBqk2c9q+SV?+=V`1@#{0Sex%SEET|<$yjJT(iO9O( zlHDSV>UhxQdo!miHYUSw{UOKfheEA50P`g$Y5Sc zPBgrWW%+UFhonHRc?tys?5NKgEo2-KCY;GZ(Mt}Ww3Dxr4*mf2Jn*DM^;3V~U8>%X z1H>U)z>0Wjr?@>KCV-JIv6E|ovdqJ!E9P-mb?zJ3LOi5<*709i*QGTPm;QM1=4`NsUyosd+^wF;y$CpnOV zp6x7dKqlgC0hVD0pRH^HfR4*xRH7+?@ylOWc*=zGz;#IQ243T`VW*&SNB&#L#hJj_ z=xMZQ$BWqBvz$PGq)(aQ>yc}Cc^c*zCe;K9surKJ$D`z^n2Zw>=)V$e8`foJ8;6_i zH+*C0ebt7!ae+42{kC1OoyKdv1i$2P&SS{);BJs~ah7M}{n9>aKnb%EHg<$E-Ae3D zss15>;9RDVv6xtokAOhh=|zhJ*&6@%68WHP5;mjB*d-60`|AX}m@1aBVSatMQ6D5*b4$OTC1VDo>pA0n5nzypTB?_uC{1rU z{Ni>~YYuPfR>4ceLbheVBkJSH797@X3DeAx6^8yq%4!i}FOr_Tt-p2Vx2B{@yX?ez z$e{f|W@q9_K)@4GQSfK?^-@UW9Zj1W#>Ax?PzyR9% zoo>;mZ}|>o{iHDX;hH^;ydnXVH~hAH>f^lZngx;%Zux#ODDN+No1!OA6c{PknlA8l zy!jF0#*r<|20^y43v?;tVoKn6+(d@oe2oqn-i$ZhpZ@uF`P@UCE=JDFnhcmUKfAW1 zlFK=XKb}eUI&+_skZOy$G>1SYyU@~G*@g&ikRaNsc)>Pv5Xc;_GbTFD=vIzB7uTls z{KMc7g`{Nda_1;%>~2Uv4SuD7$GiP?Y!&$@ls`}-6;UL!{H3$r`yH(;o&Ejz=IB?p z#oUy|WH$deDWmak1*QMSQH4QzM7wKu=&-FZ63neEJRhbe)u_(1b0Qfss=dIogvu1( z3ADq{nygZi#Y}{-*HK8XC&xST#xQt#db;mQjfg(Uu5`M<2easR{5u{7s&!aNsO9H&t8~kbbmv9c-2k*-=-MgglZY=7Qhea!@Dyl~o z_(ss^EvH2)$Vka#!4w9fNm&fTf4;2`#WX6>!Mdx$r1A7x;&G3b^o;))yEH=$g+Tbb zgfME5kSQEpQXXnJUa>lwn?roRw%kAXRQc9A3L;^_@ux_WN)Fo4q>4-VSe_>J@Wf9f zM>e9NPIxrCs)tu5QVOL+O*dd@ERPCa{kB|^l*B4AL{VHT{ngU5vz(@OidWz(-8~e{ zyPIy6-}K_Uq$sVPUTF!%ILh&v0gnD{4sR>6q(S618=D5s})Y{xfG;G&7t$!Xz zSSF_)o_|f(7wKjvlCbBYtRvZ8)Z-2cP}$M?{@?4{R*!r2Dx_CaxBG8yUV2~b@; zUSg=qt*y2{zSoV)*|GdJ84^2pSwMz(W;dg3=rl>8%hgO;);ryDqsu{2A-E*G58eDw zJ6ApmZOKAoP!yCMP3v>sBN?xB4*Lhb-(N>niOTj>G}kJSRHR$FIaltU_rYPs#}9ek zG*(u|(Jc0Fnto8`V~~FO6!H$5h;dhCIc#h!3@ao&cHXbNWUp>}PHks6`i2|W&f9tv zW#0G?+Gfr?I)`69S zQ-JvHsJyG|`=uS^tZ=!WF`j@25q)Xae;-=kU27akJ^8mIC}`{J2j`MKa~Ghjo69U( zoMKs}vES?0G~f zsbit6WRl(UYAI40)i^^;p2HHgE|DHrP%t98H*(8fii$D@9cv-^aeM|Ix*Sf(4u`pp zSCuoq|G&Zbwl4)<0!##5QAxq^3jI_g-WG_9R%vZ&w#v(GO%9)IR+Oh6o{1jj$54%~ z1qV~0heTY=3Tib7aLD=Gp^h_5w8!YsYz8+)E}W@GC?Lm_u^9co-4Rv*=EZDeL4U#rmI$ZD zRY~{#0SvwKw`>?)W7O%TPwIEA-jxKEk41Apx+NuB&(3zzhI#{UhSO}nhkd1CLOpeV zl#Cs_5GLNb)DD3JgyTft-@9Bh(!C$KI)Om$svIlYBBi}H>bEN@eq z;CbiI=l7i_CvAN9Er?F1rw`rNUPPIjqlkwsp89?PQ-w);Q3#$9`Cw%4r_2=^&MQ-; zXOMTWMxHUo1mwUrTu8j1Um3HuJ1^LT-jpzoS}1@ONudTo59qNhNZ=Ip+kpDbwIk=J zQqT}{`e(`$Oj;a!Um3f$+_G9zzFA(Tw~I+JmHFSp^MldvHr(;Nayj>kQ-3ZG`oSUb zr2vd!QxoC+eN>qcyf?FK&9~1cxvcyJ*`7BaNCzo}lRe>0wwsw7cRNXKut$L`lDxo* zuO4%3*x}mu^f{wD-P(HN)mQat<>GpmAvO0=MBm30cC_D1_cs&{&Vk?zJVwQ$a2O+k z3R2Umn3Wr|t;)($!I{`_Wk^Xd_mFXvUH8`n!1?ZCc>mv2+njGje8i+xqbTa8Y5TDs zPabr*T}7CWz1g9DoQ%Ci?t`gq@`gh4Fp=kEuzJ~G_5K)xNfsn67fDF5pgkqOaFc^W zkRUa($(=!+k=h*by3faii;x}>fcXylgrcn zm<5~A%swm)6VtzK^tqIe4^2~}l$aol>DIO7R54(KA*i}-FNQV|+G{(8``=_7N5xQ5 zHti+;p%kUS6@br_zF5*eY-%h#n*V}?*YWMN_fZV@suc_!5Q2Z(sJ$b7W}x-3v9|OH z;p2~xWb8;WPBjbLNQZ@H+eSN5 zbQ~O-+q;(+wRf3yvvTQ{&89WeX4!v9W~z&hol3K}!+xj3&Y4NgifqI1~ z=kSl@oP@OhPDVJJEr}4;bs+;{s=$LzDEw(OEc0?7@o3g8ZWm5NAT?9kovLA`lJSKnU|yr z&QAMktALNnnKCP{|9$TC>!1MI{O2)0Iwj$*NF=<@{*DcYnSFy|)>7z@cP<=QHL_5q zcsdskUhOUMt1Y$)NYi|#P+bKZWfhUVNNc|IQEuzWkg#H&RCfQ>|I^)fhjaP9f8SIp zinqN(q>O~@kxe3H?=4$KLUywEC@UkRvO;EN_6i|n@06Lc=kL5feZJ4{`2GDH&+#13 z;~xj_ao_iKo#*))*L9!wp2YX*`zp~f|CCcDHlNZ)?J3=VF|Wmrl9CcZ?%nIHy6_io zf|mqg%#_s1U6i2`w9DV|Q|(Pdkfba|`Z%nXmM<2DH&{eE%v!PISUE90Ui;!}v)GTj z;-h)+nZtY!pl|h|c{ew8x7jOv{PJUYqQCK>Tvo&>!>2h|7WP5I6VI+UEt=&;w4hLC z-`jNYwKe>J1ae>|tt^kl<&8T&YTlrhQ~HZJ(&f#KBqMt2>6;qPYc#kg$78$x{12J} zU(iVW`-UCvTe4xcxmEYv6;3u)&ov3D{QH7hVKE$1eXvgjC*Z2m<*R}tfq&11kg#vU zFM#Dk(*KB*OdlMe)8k3?*03ne620L9p?X&E&ySu zCXoYV{QDE54_Er{+ofF~ga7lw$Qiu9d=FkUBS)e1|Md^xdYl#KeYEtIWqxrn3+`m! z#y^9+8u*DIxM!wRwg~^|o^Z8oQUB4c=c@nDFS=pd`>m+e*e>FYcP)<&+vbtKg)=>T& zL;fV&(Hw>B!_~>!v7VUd>zwsl=rR8Lh;aPBr(r`&!}Y&2F>inW_l!9-{~wyxh0zLo zy8Z1{Jbr$D%Q7A3)lpKoJ4EV5l-QHsOmFVK3~amPaRCgWXr}WL5)$QjiMiQ)d!&*^ z18?q9hP#crPIP4T4^P9<*~!bs>gURC;>VJvTy+yQqllFKWwI} zSLeF_6kRzxh-pN7~w&HJ(F1)V^*H zCnO}q4IZezG;P~XWHVoxPj29bHC$*|AyKNPKlItY-uviFhSb>-m*c3*2gMrR1hhi_ z?d`YC9z2+V2iZ5k4I`1`<9e^(yfHc2--3CoiRkrbFP8N-bpLt9%);VjnPIid#dPm| zmhp*+z8ux7)YA7}z{~eig#tMAtLPaSPbn)a2Zw|-L{Re=nRio#Q*dFku&~^|ef!gR8OPyAxl#!9i zgnN=C3pB!>oh6_eD=8_Ff&1bD?qm~L^rYmJrpH~q7dlbx8g_Vywj{GI(}AfOy=I-u zYBOI^RCFI67q4n&Zk}A$*4%s!G3iMWEI*gdYFu!z^S%2e*%-tID-s+@E3Ct2mdk9j zP+9tdO%oTHg_$Nqq}9}(zsP!%`$5B-n26{!5)&P5RBe##<O%hG z6Kfu-uvI7!4z8|k&>te7lv_Qn%*i%7EiDq-_L%PensJ+tk+a2csRxz)QE(-e`ziv{oenxl&N)_ENymsWrL zm?0N+3ArQNraVz18+i{dp#FFtFIDJqqiK|=`G@Nq9QbKzX+NwspkaVFy?X zH=YQA04%YCxIwg>D-%Y9>S1~jAC_`*-9^`(wTbr5PDx8k%ZP(lqs10aB~?F&wx zhZ}7gAyL+CqkE!mzvy8)gP;x7+`)RioUCl1`t0ju2=4lmV~^Xh{!dSxo&8p9s#WR0 z@cnbrv&YLr8AU~3a=BCoM}AxUdaqv{otzvBQGgaE-pCm19-TejL;M$siI)9osi_f$ zK1ZZ=b#;~#RgbS{WhWQOMqcb1K7ID=OS`xZvvcRpiF&1_~$!{23`1Jd%&`4U@pXw)n%%Pzc9zBn+Jm^vt2M8OncJT`fGEScRu`aLJ&^=xIkiEe#EZjxK80z@xzB zsY-+vV|_qCfPm*NXH88_k;9B+b6ZX}dUemXa*e za-R6Au|a^sbcSv%56{89rLrn2L{OwH^h`VZ`Uu62Hi<1#zJ>5QjlYsOg?xg#)m=_X zPChsO5sVoaSZau3kxsIxXFak+QNl>oa+uJ`u53*;1 zkG|7izkZRNL+=wrsN?>oRdaLmuk#;WjPj<6ogO{fXd)2%1ebYFWJ9{ZZZ-XgfxT}azpKoP<#4b%M-7RRS;0~KjazVgVdgD zf1p4huzS_iI5arlWU9`Kg4^tzp1%GD)KVy1W5P@y{2-Sk&pk~^p;?)zCcmyzDh^>? zY|?`35hPo-`T1tAc#c|f#Jh}r26pxfH)?)iOG-*kPfsI0K0X)dZqamfbO4wo^4U>2 z3#HaJy{6&}kL+LGa$Ebjh#@CGYj1DAdvGw9>N?r*O5&Nr_KG&#AKu>+6INeRCvT`a0%%A5M8`Dbxy>C>C;wdf&Y zna)F|Q9i$O<&wt=*995@43t_~H8qmi+1c2C(&IzAbkhQWBQEAys0 z{rc}}pA&CsSy@TuI0ivMO1sh0oX06@H_NMhQd3gCe0*e)uU&$3UAMd$5;7(&%^*)J z#~eT*TOq*&h{7)1v=|}P^K;E;IlTcsD-m@`G&(6MUQfJ$GM9D#tKV1{y*Y$EpqTxP{YUjT`@>h-!3 z>lHQLp-xA*5EPAXdF|6*xpD@Nj9pw-7RP)4=T-lJ04$`Ywzh!WO%!B!64&RPG9I68 zxbXr8vP@t-e(fT}^6}BZ#;+y5T)GrlQzLTRXL@z0 zeWT&oSuErk&P6{a`PLgvlBw|aXr_+&5FR!DrCE!j42q7c*n^@r^O|{sjZ;%};vK3r zo!q4pHBU-h8dSF1g^jesQGNgH8jX@adC@aHa6F%=z<97a1f3#WzfQAt@m5d?;k0j2l?*x zW74$ykSbKgz4ns4y!?;;?658DY^t2FgB{36WKHC0%u+lQfdXBAsQSrX7b&?6a;xraviSa%$YNTJc_vty5seM zeJXm@1IPI>hE1wB3-znb9jlFbgc6#I2y{;_-169F_ug9|FMT+XOz`fS0_KIXni^9w zkJaU`)@1^2zg`uK3JdG<*iY)cQ)5Ll+VlXT2B=SNdr!~%BEvGB>)5AG*%9-oCzbsZ zg)`s3GpD4aG~cK{RFUgbeE`4>1zvkajDw3w;qv4>&ok0jm zNgq6G&tc8qqtuUAdP^L~e3uB3c5|yNPcOla!9PKtoYC|DmW~+?2|+=hMpS%@*Wx*@s=YWvJv`n z+Yys4)eI*swl+aQ!Az*yudl;pZUUYl7@OhuWeZm0ap0zfsi>%+UkAL>1e!68`krlPH{kpxKYUc(*NJYgOc#)c!`Cb}(n$Pi*IRe#e6&8&`1N|Ti z=?B`3ZSh5%k6nv~M91Ij3vWPw#}ELZyWx~KWzk7NFcDO|L|5;n<8LW_q&q~ z25P*r7%f!IiCm`s6%`c-J#@CH6%wLm?MhDXWz&35f;9VLom!%I%Eh&r=jY;5zD9zJ z1%*gD7{Cja04U2sf&s&aW+;`lA*FsVbjbjEw6w=Cwk#LcKa8D%`;foRjEoCwC1yIU zPE>0`4aWm0nwg%a7jj)=uRB;1E;;I~X1M4xt+uINHLLhIk8iKw&OAiUF3CR#Z#<5D#-0?x8mDBf6mV;dG%Xs`S|Yb+ZbA57m~=xNXfjxP|y>4;_7Q9RDBG%Re!Yol7h!F+VWdb#{Fv~jU65G&P!kA0om^VaIO*(z{|aD_sH|d z7lsX(D=?&~abZ%Zcr58Jd2K4k#KywuIjqO{V;XU9V*l%s1PA8^gryVzs|o+_tA4e0 zXrXx;G#7M}+{SO=02+QuK9$U2Z==5{~fVps5;h~3IT0|ry{;yuqdju)l-+jq< z5k7JNw@#OT{;XJ8Rb}#fzGE9NJUDnw=O2q=kU^?JJeFy9e_!(LfcE$eUf#h(lxEz=@7^r^+1Zf;DK_bI;#FO93<`+TbxrRAR7@-^EatJHeC=mDQ$BQK74|`e zg@yXF)x-Jv)nqyaWVIeU<*wIwxBC@tWB{bxwCE{yy(TT~oG@cw4VmEb%jaar$Kmng zmf4oD?((0`LY_eKe@s;KKfjYro0;X23KFm@{(4`Lu@oE^@+0Qu%P>Ml*+e^P9!m^9 zjesv`I@^&CE%<#Cy3Rthd=JNi|4LPlAXQV~q58ArpXz&RDJ)ym+Ii) z@CnpevE!VqUak9}*%$<1bbNf^+jQ@$@wotuR*f51mFqe_WE-mNP1_>L@bnWn^d&t# zMXFp@ae2-e?%tj;@IHYjzs&-M?I3U~PmT{}e*CaZSyR{3 z3kOUgN)>hu3<^36D(oLTY;I1}vr!l*fingq2%n;^lZ~JXm}`rYVvftOpQTKD0ypA>-@e+4GY zebyI%^Yq8?Uq;}LATor3M@M+c$;cS5PuBjJtQF48%3_G?_=n@_R^T@U5y%-C85Q(+ z48PY$EhUtD6`Qurf=J;NDg;viQXCJsk1Xgc3>jGh;0 z1))Q;*qbf~5O_CIAC81h4`>KAchM49=z~**OqnEBRSv+ot(Fb{YB)=0^yQVrF7TEN zl^o=`-nw>)c>{3GQ#coV^YYkG9t_Q-NTHQ$nU;|@m{zfIaT~kv010quzCmkeA|j%q zM~g~_Fqgm{6w|Q z=-lTGnt*x^4hw7QNfj=3-?9c(j!!Ez4JKN18=X%$bRK=_`JG~@Jg;X_?m^Y#WgKL8 zZ!w*anp!v86oQRT9MDOA-3a%uLvy;~y9p%4au9QGqfr;0#>Gji>*Y1(dwG zk$!S0hhU%*kbyRC-3AXhBr%Z!OMGRje(GCVd^|H`*Jl&$D(4)8-n=Kp%J$$-uh=#D_4?g zzC+1`&otuQNG#P{7=PyHCk;-e;wWq999T;Jk{OuE$w@TrTDQTvMD-VzIJ<87=Whm+ z9xbng(&;3%>UI#=%$KVy1oE#th7%rp(;?KIe2u1~h(9xyqw3yoL?d_p2fP@PKjWGxIA{XV%7PR3) z5V(W>L_BYhYV3-9BS>R<`PhzwKGYu#4h`kD^7zxu4>~0feDNf`u}<*3=umWYq%u^` z*DA&qO2v0{bYy0n0;DtVO}mJg51e6BiR+CD-QC@t2K{C}U=RYOI!tvB272X3&fBy` zo5I>ah}|rxw_wVC^4giO;7e_z7ZXD@VqkiD`bN~}euxqT0W*)+%COd*7lPFvN*OJ5 zC1hTn)2njs*!2SK`2}Qv=fP@4rNa#N{?2-oL9IJ@`1l$x;|Ft8-|D0f0|9aXod!@p zIvY3PzP%EX?sNR{&IVX9(_m7zL3~HSe<)$Kyuj5v1@W}DZ;p?TZ@xomv{ySbKaW-}BRDsK$4ewo_B5=*#XI&jkX-dd?wS+! zyUQy>1=KuL7vWuM3&*~>KHh2pKofv=eJClJSz9wGFbxTQ=WPOZQE)pr1vMCA;%Do1 z2l(I;Gy-^mf!FA?)7#nUhrtF&ZQlGj&^2-6#*NXXrh23x&LCHzf`~8d*a(RfWHaArh}7{Bp4@tety$H zg!JaCHcfu3hn&ExDBA*<;C%F55iECNst6pg>s#S;0eGIVrdM16V%P%wY6Af7vXBrJ zcmlg{S{NFd@X1;a%^ZRzkGxl}UTGAowe}yFeSN2_&Z@WoE(c&Qx`qSlyTyz9=FN|&_zG@|Y#u&{x8lK_WMpEdwsPx(5YR{u#TjzfC1p})`K z$t!xJ>3BhuG=G!$OYUVJ9ug3|(#pz&C?$vDV%*?&26YP$4-dc8dbpMxpTfpwuQUJ2 z;thrUnQT@As9p>ak;#V>*C{W}ha&Os3Gb^2r0`e~G2c!0hYU7`rfz5!1EX|U=p=)Y$o1ku;#xvdl6IaJtLn=fBPiNhrIs;{ z($j!#UjU!tyx?v0@YYZTOVd8})@`>4%BT}O%88>Lj8f?@t~YVb77@s{xzKna+S zzvbqIu~sxUXHH2?&Ctlp^J0g=zNcr3#&-{Z_4N9BR2cbOE}fd^k~+6u~iC?^Ex%>31cT33ZO$j`tReA9Rr+pMVmCk6m=sq4H98 z5z>8->tVpb!2xWw70?I0@{(0>GxLX*slnJmS0a&^GP20<1iQ+4#L^E8+vxZOiox7X?5u25ti$6p+A5?Hc_mX9@5Ufkqk>cwevO z2aCxO)O%%RWuhYCiw?7e?~M+H+R3e&%9k##*8}~2X*X8wdIP*V#W%Q`Up4J+iLrNd zbu|Uxk$38gMFa;QC!w3{E`PZBQ7@9$W@sAFvJsX!41XRIUsaNZwTISlN>fMHK0p@7 z@1H4=D~wVgNq)gw7{$bTbf!?`4eXR?5}f;F#*Wpnkx0^JH!wv}vy89M*{yJU`{d*B-1@!z{D{yHRwSpP4b%W3S6c1rj8>+w{}3pUc&s zx1nS-gEFSHXoT4I3VeK}RZ!Aj`&f3}<7_%mZgGYz&K11HU8eT?;(<)l5k7w0x zp?fLB;Y|Br-dR_L0Uw~AorN51DXc#Zg(}w6+e-k6ApkoCP{bl-)1Zd}%bj$g7urvt z>r+5-8JU@J_#cjc0_4E}9n)|SaQ=b8xj^9yw3wQltln+oGPne%k`x1B5qqTapVit@ zD^=7!0m|P7U>)@JtJ>?=uU~aUErIb%F$XtL>rPhO;uj}RLyhymXxZHqKMJy)(j$DL- z%_&lpFqBav8cWYhMJE0F(4JB;Vl%LM!w%_ zJt|uM{G6(KF=RGX zRK$rTK3?-gRWVIOo9E%U7SMB49ih0Ytfy!D!N<|D7^Pt{FZt~JWwQuj`Smg%p9X3{ zebcmxfB#a#*f=9m)T11fJPGPMv%a_kls6dS{yq@n#`&_pN-230T|+<(ZES38P$@-~ zQ5bmixg+@SsMq^;K;8Xf*J!nCA!xwfG`sTdU$@@<=!);GTU2Pv9<&B~DGJRe3)`s= z55geQ2!T^UA;X1jAbqeR0uM8~;J7v}IOm-`^#~Sf31O|92&iJ?zRd0<_ zyxU4jvhMCe;441HE-ir)j{4%Qpd`3Vn$LLzfh*-2vwxzCje5)KNf>OL z_21f^K1Cz!lHd52Aoe4KB-52E8Ia>!Z?2s|pNb;~_N(F>LfCYN^$#@&4n}Bhc~Zbq z)h7tdBG}Io4p0kl``*z9liz6%;gTv73}n8_y4Idy`+ zd)E9o9Ov4L@;nGa4uK1Z+ z{l_kE=lBO<|JmJmxf<>WK;MA=tiWEpzyAgN#Up$3_rE}E{=fLZF&#HwZM^4P6d?jX N@-j-&? @@ -414,9 +433,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + // R26 split-launch extras + args.pv_threshold_per_head_ptr, + args.head_remap_ptr, + args.nhead_in_launch); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + // R26 split-launch: when head_remap is active, gridDim.y shrinks to bucket size. + const ck_tile::index_t grid_nhead = (args.head_remap_ptr != nullptr && args.nhead_in_launch > 0) + ? args.nhead_in_launch + : args.nhead_q; + dim3 grids = FmhaKernel::GridSize(args.batch, grid_nhead, args.max_seqlen_q, args.hdim_v); return ck_tile::make_tuple(kargs, grids); } @@ -459,14 +486,15 @@ using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits; float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); -// R25 V0: kEnablePVSkip is now a template non-type param so the codegen can -// emit both true / false instantiations from the same source tree. The host -// dispatch (fmha_sparge_fwd_api.cpp) selects the right specialization based -// on fmha_sparge_fwd_args::pv_skip_compile at runtime. -template +// R25 V0 / R30: PV-skip mode is a template non-type param so codegen can emit +// all 3 instantiations from the same source tree. The host dispatch +// (fmha_sparge_fwd_api.cpp) selects the right specialization based on +// fmha_sparge_fwd_args::pv_mode_compile at runtime. +// 0 = kNone, 1 = kPerWave, 2 = kPerBlock (matches ck_tile::PVSkipMode). +template float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); -template +template void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args); void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits, diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp index 06be6215bc..10a58ae05f 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -4,11 +4,14 @@ #include "sparge_blockmap_trek.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/host/device_memory.hpp" #include +#include #include #include #include +#include // ============================================================================ // Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) @@ -265,6 +268,21 @@ float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, [=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); }); } +// R26 split-launch: partition heads into two buckets by per-head pv_threshold +// (sentinel >= 1e29f vs finite), materialise device-side remap LUTs, then issue +// one fmha launch per non-empty bucket. Bucket selection happens entirely on +// the host; the kernel just reads head_remap_ptr[blockIdx.y] to recover the +// original head index. +// +// R30: the "finite" bucket binary is selected by attn_a.pv_mode_compile (1 = +// per-wave (R25 A1 default); 2 = per-block (R30)). The "sentinel" bucket is +// always kNone (mode 0) — sentinel heads requested PV-skip OFF so the per-head +// per-mode choice is degenerate. Per-head per-mode bucket (3+ buckets) is a +// future R31 extension; this commit keeps the 2-bucket scheme and routes the +// active mode through attn_true.pv_mode_compile. +// +// Backward compat: if pv_threshold_per_head_ptr is null, fall back to the +// original single-launch path using attn_a.pv_threshold scalar. float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, fmha_sparge_fwd_traits attn_t, @@ -277,11 +295,115 @@ float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, << ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; - return ck_tile::launch_kernel( - s, - [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, - [=](const ck_tile::stream_config& s_) { - sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); - }, - [=](const ck_tile::stream_config& s_) { fmha_sparge_fwd_oneshot(attn_t, attn_a, s_); }); + // Decide bucket plan. Pull per-head thresholds from device buffer when set, + // else broadcast the scalar across all heads to a single bucket. + const int nhead_q = attn_a.nhead_q; + std::vector false_heads; // pv_threshold >= 1e29f -> kNone binary (mode 0) + std::vector true_heads; // finite -> kPerWave or kPerBlock binary + false_heads.reserve(nhead_q); + true_heads.reserve(nhead_q); + + if(attn_a.pv_threshold_per_head_ptr != nullptr) + { + std::vector pv_host(nhead_q); + auto err = hipMemcpy(pv_host.data(), + attn_a.pv_threshold_per_head_ptr, + static_cast(nhead_q) * sizeof(float), + hipMemcpyDeviceToHost); + if(err != hipSuccess) + { + std::cerr << "sparge_sparge_fwd_combined: hipMemcpy pv_threshold_per_head failed: " + << hipGetErrorString(err) << std::endl; + return -1.f; + } + for(int h = 0; h < nhead_q; ++h) + { + if(pv_host[h] >= 1e29f) + false_heads.push_back(h); + else + true_heads.push_back(h); + } + } + else + { + // Scalar mode: identity remap, single binary picked by pv_mode_compile + // (R30) or the legacy pv_skip_compile bool (R25 A1). When the scalar + // pv_threshold is the sentinel, force the kNone binary regardless of + // mode_compile — the mode is then irrelevant because no skip happens. + if(attn_a.pv_threshold >= 1e29f) + for(int h = 0; h < nhead_q; ++h) + false_heads.push_back(h); + else + for(int h = 0; h < nhead_q; ++h) + true_heads.push_back(h); + } + + // R26-R3 gate: skip empty buckets so we never schedule a zero-grid launch. + const bool need_false = !false_heads.empty(); + const bool need_true = !true_heads.empty(); + + // Materialise per-bucket head-remap device buffers (one int32 each, freed at + // end of this function -- before that we keep them alive across the launch). + ck_tile::DeviceMem false_remap_dev(std::max(1, false_heads.size() * sizeof(int32_t))); + ck_tile::DeviceMem true_remap_dev(std::max(1, true_heads.size() * sizeof(int32_t))); + if(need_false) + false_remap_dev.ToDevice(false_heads.data()); + if(need_true) + true_remap_dev.ToDevice(true_heads.data()); + + // Build per-bucket attn args. Scalar pv_threshold field is left as-is so the + // device fallback (when pv_threshold_per_head is null and remap is null) + // remains correct; per-head buffer takes priority when remap is active. + fmha_sparge_fwd_args attn_false = attn_a; + fmha_sparge_fwd_args attn_true = attn_a; + // R30: derive the effective per-bucket mode. The "true" (finite-threshold) + // bucket inherits attn_a.pv_mode_compile so the CLI --pv_mode picks per-wave + // (1) or per-block (2). The "false" (sentinel) bucket is always mode 0 + // (kNone). If a caller still sets only the legacy pv_skip_compile bool + // (R25-A1-era) and leaves pv_mode_compile at its default 1, the behaviour + // is unchanged. + if(need_false) + { + attn_false.head_remap_ptr = static_cast(false_remap_dev.GetDeviceBuffer()); + attn_false.nhead_in_launch = static_cast(false_heads.size()); + attn_false.pv_skip_compile = false; // legacy bool — kept consistent + attn_false.pv_mode_compile = 0; // route to kNone binary (R30) + } + if(need_true) + { + attn_true.head_remap_ptr = static_cast(true_remap_dev.GetDeviceBuffer()); + attn_true.nhead_in_launch = static_cast(true_heads.size()); + attn_true.pv_skip_compile = true; // legacy bool — kept consistent + // R30: pv_mode_compile carries through unchanged from attn_a (CLI choice). + // attn_true is a copy of attn_a, so attn_true.pv_mode_compile already + // holds the user's selection (0 = kNone, 1 = per-wave, 2 = per-block). + // We deliberately do NOT override mode 0 here: if the user passes + // --pv_mode=none together with a finite pv_threshold, that is an + // explicit "build the bucket but don't skip" request (useful as a + // control measurement). Routing it to kNone keeps the CLI honest. + } + + // Chain callables: kstats -> blockmap -> [fmha_false?] -> [fmha_true?]. + // Empty buckets are skipped by emitting an empty lambda; the wrapped path + // never issues a kernel launch in that branch. + auto cb_kstats = [=](const ck_tile::stream_config& s_) { + sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_bmap = [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_fmha_false = [=](const ck_tile::stream_config& s_) { + if(need_false) + fmha_sparge_fwd_oneshot(attn_t, attn_false, s_); + }; + auto cb_fmha_true = [=](const ck_tile::stream_config& s_) { + if(need_true) + fmha_sparge_fwd_oneshot(attn_t, attn_true, s_); + }; + + // launch_kernel returns elapsed ms for the whole chain when timing is on. + // We always pass 4 callables and gate execution inside the lambda; this + // keeps the timing contract stable, while a no-op lambda has negligible + // (~ns) cost compared to the saved 5-15us host launch. + return ck_tile::launch_kernel(s, cb_kstats, cb_bmap, cb_fmha_false, cb_fmha_true); } diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index ae0952cc41..8ba1b97e84 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -110,11 +111,22 @@ auto create_args(int argc, char* argv[]) .insert("pv_threshold", "1e30", "SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip") + .insert("pv_threshold_per_head", + "", + "R26 split-launch: comma-separated per-head pv_threshold list " + "(length must == h). Empty = scalar mode using -pv_threshold.") .insert("pv_skip_compile", "1", "R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use " "kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to " - "VSA baseline)"); + "VSA baseline). Deprecated by -pv_mode; kept for back-compat scripts.") + .insert("pv_mode", + "warp", + "R30: PV-skip mode select. one of {none, warp, block}. " + "none = no skip (kNone binary; matches VSA baseline). " + "warp = per-wavefront butterfly vote (R25 A1; default). " + "block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). " + "Overrides -pv_skip_compile when set explicitly."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -147,6 +159,31 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::string dump_o_path = arg_parser.get_str("dump_o"); float pv_threshold = arg_parser.get_float("pv_threshold"); int pv_skip_compile = arg_parser.get_int("pv_skip_compile"); + std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head"); + std::string pv_mode_str = arg_parser.get_str("pv_mode"); + + // R30: --pv_mode maps to the int dispatched at host. + // none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock). + // Back-compat: if the user explicitly passed -pv_skip_compile=0 but left + // -pv_mode at default ("warp"), honour the legacy intent (mode=0). The CLI + // doesn't expose "was this passed explicitly", so we mirror the rule used + // pre-R30: bool 0 => kNone, bool 1 => kPerWave. + int pv_mode_compile; + if(pv_mode_str == "none") + pv_mode_compile = 0; + else if(pv_mode_str == "warp") + pv_mode_compile = 1; + else if(pv_mode_str == "block") + pv_mode_compile = 2; + else + { + std::cerr << "Unknown -pv_mode value: '" << pv_mode_str + << "' (expected one of: none, warp, block)" << std::endl; + return false; + } + // Legacy bool wins iff user explicitly disabled and pv_mode stayed warp. + if(pv_skip_compile == 0 && pv_mode_str == "warp") + pv_mode_compile = 0; if(nhead_k < 0) nhead_k = nhead; @@ -271,6 +308,31 @@ bool run_test(const ck_tile::ArgParser& arg_parser) static_cast(cdf_per_head_dev.GetDeviceBuffer()); } + // R26 split-launch: optional per-head pv_threshold buffer. Parse the CLI + // comma list (length must match nhead); empty list -> scalar broadcast + // (legacy path, single launch via host). + ck_tile::DeviceMem pv_per_head_dev(static_cast(nhead) * sizeof(float)); + std::vector pv_per_head_host; + bool use_pv_per_head = false; + if(!pv_per_head_s.empty()) + { + std::stringstream ss(pv_per_head_s); + std::string item; + while(std::getline(ss, item, ',')) + { + if(!item.empty()) + pv_per_head_host.push_back(std::stof(item)); + } + if(static_cast(pv_per_head_host.size()) != nhead) + { + std::cerr << "\n[pv_threshold_per_head] length " << pv_per_head_host.size() + << " != h=" << nhead << std::endl; + return false; + } + pv_per_head_dev.ToDevice(pv_per_head_host.data()); + use_pv_per_head = true; + } + // ---- build attention args ---- ck_tile::stream_config stream_cfg; stream_cfg.stream_id_ = nullptr; @@ -354,21 +416,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.scale_s = scale_s; attn_args.pv_threshold = pv_threshold; attn_args.pv_skip_compile = (pv_skip_compile != 0); - attn_args.stride_q = q_strides[i_perm ? 2 : 1]; - attn_args.stride_k = k_strides[i_perm ? 2 : 1]; - attn_args.stride_v = v_strides[i_perm ? 2 : 1]; - attn_args.stride_o = o_strides[o_perm ? 2 : 1]; - attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; - attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; - attn_args.batch_stride_q = q_strides[0]; - attn_args.batch_stride_k = k_strides[0]; - attn_args.batch_stride_v = v_strides[0]; - attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.pv_mode_compile = pv_mode_compile; // R30: 0=none,1=warp,2=block + // R26 split-launch: when CLI provided per-head list, hand the device + // buffer to the combined wrapper; host code there will partition heads + // into 2 buckets and issue per-bucket launches. + attn_args.pv_threshold_per_head_ptr = + use_pv_per_head ? static_cast(pv_per_head_dev.GetDeviceBuffer()) + : nullptr; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; avg_ms = sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp index d600ff7075..cbca128ca6 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp @@ -7,6 +7,9 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +// PVSkipMode enum lives in the sparge pipeline header; pull it in so the +// kernel template arg can name it (R30: promote bool kEnablePVSkip_ to 3-way enum). +#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp" #include #include @@ -21,7 +24,9 @@ namespace ck_tile { -template +template struct FmhaFwdSpargeKernel { using FmhaPipeline = ck_tile::remove_cvref_t; @@ -30,7 +35,9 @@ struct FmhaFwdSpargeKernel static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; - static constexpr bool kEnablePVSkip = kEnablePVSkip_; + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias preserved: any non-kNone mode is "PV-skip enabled". + static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; @@ -99,6 +106,15 @@ struct FmhaFwdSpargeKernel ck_tile::index_t nhead_ratio_qk; float scale_s; float pv_threshold; + // R26 split-launch: when non-null, indexed by remapped i_nhead (post head_remap), + // overrides scalar pv_threshold. Buffer length = num_head_q. + const float* pv_threshold_per_head; + // R26 split-launch: when non-null, i_nhead = head_remap_ptr[blockIdx.y]. + // Buffer length = nhead_in_launch. Null = identity (blockIdx.y directly). + const int* head_remap_ptr; + // R26 split-launch: gridDim.y when head_remap_ptr is active (== bucket size). + // Kept for future host-side asserts / debug; kernel reads via blockIdx.y. + ck_tile::index_t nhead_in_launch; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -165,7 +181,12 @@ struct FmhaFwdSpargeKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + // R26 split-launch (default-null preserves + // backward compat = scalar mode). + const float* pv_threshold_per_head = nullptr, + const int* head_remap_ptr = nullptr, + ck_tile::index_t nhead_in_launch = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -185,6 +206,9 @@ struct FmhaFwdSpargeKernel scale_s, #endif pv_threshold, + pv_threshold_per_head, + head_remap_ptr, + nhead_in_launch, stride_q, stride_k, stride_v, @@ -224,7 +248,18 @@ struct FmhaFwdSpargeKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; + // R26 split-launch: if head_remap_ptr is set, translate the launch-local + // head index to the original num_head_q-space index. Null pointer -> + // identity (single-launch backward compat). The remap LUT load is uniform + // across the wavefront (same blockIdx.y for all lanes), but the compiler + // can't infer scalar-uniformity through a global ptr indirection, so we + // broadcast via readfirstlane. Without this, dependent offset/buffer- + // descriptor computations spill to VGPRs and buffer_load_dwordx4 inline + // asm rejects the VGPR operand. + const index_t i_nhead = + (kargs.head_remap_ptr != nullptr) + ? __builtin_amdgcn_readfirstlane(kargs.head_remap_ptr[blockIdx.y]) + : static_cast(blockIdx.y); const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { @@ -402,6 +437,23 @@ struct FmhaFwdSpargeKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + // R26 split-launch: per-head pv_threshold override (null = scalar mode). + // i_nhead is already scalar-broadcast in GetTileIndex; the load is uniform + // and the resulting float lands in SGPRs naturally. We additionally route + // via readfirstlane on the int representation as a defensive hint to keep + // it scalar even when the compiler is conservative about float traffic. + float pv_threshold_resolved; + if(kargs.pv_threshold_per_head != nullptr) + { + const int raw = __builtin_amdgcn_readfirstlane( + __builtin_bit_cast(int, kargs.pv_threshold_per_head[i_nhead])); + pv_threshold_resolved = __builtin_bit_cast(float, raw); + } + else + { + pv_threshold_resolved = kargs.pv_threshold; + } + auto o_acc_tile = FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, @@ -409,7 +461,7 @@ struct FmhaFwdSpargeKernel valid_block_num_value, mask, kargs.scale_s, - kargs.pv_threshold, + pv_threshold_resolved, variant, variant_params, block_indices, diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp index 4bf3c9d296..0a8baa4e62 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp @@ -11,18 +11,40 @@ namespace ck_tile { +// R30: PV-skip mode enum. R25 A1 shipped a per-wavefront vote; R30 adds a +// per-block consensus vote (matches upstream SpargeAttn kPerBlock semantics; +// see R29 researcher report per_block_vload_guard.md). kNone disables the +// skip path entirely (AST removed). The legacy bool `kEnablePVSkip_=true` +// maps to kPerWave; `false` maps to kNone — preserved via codegen. +enum class PVSkipMode : int +{ + kNone = 0, + kPerWave = 1, + kPerBlock = 2, +}; + // Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA; // adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original // _vsa.hpp can remain frozen as an A/B baseline. // +// R30: kPVSkipMode_ promoted from bool to 3-value enum {kNone, kPerWave, kPerBlock}. +// kPerWave is the R25 A1 shipped path; kPerBlock adds a block-wide consensus AND vote +// (1 LDS slot + 1 block_sync_lds) so all waves in a block agree before skipping the +// PV mma. Per R29 audit, the V load / V->LDS store / cp_async pipeline stay +// unconditional in BOTH per-wave and per-block modes (only the gemm_1 is gated). +// // QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs; // _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim. template + typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + PVSkipMode kPVSkipMode_ = PVSkipMode::kPerWave> struct BlockFmhaPipelineQRKSVSAsyncSparge { - static constexpr bool kEnablePVSkip = kEnablePVSkip_; + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias: true iff any PV-skip mode (per-wave or per-block) is active. + // Kept so existing `if constexpr (kEnablePVSkip)` reads still compile. + static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); + static constexpr bool kPerBlockPVSkip = (kPVSkipMode_ == PVSkipMode::kPerBlock); using Problem = remove_cvref_t; using Policy = remove_cvref_t; @@ -140,7 +162,22 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge static constexpr const char* name = "qr_async"; + // R30: per-block PV-skip needs one int32 LDS slot to broadcast the AND-vote + // result across waves. Reserved at the TAIL of the pipeline's LDS budget + // (after the existing K + V allocations), 4 bytes, aligned. When mode is + // kNone or kPerWave the byte is unused; the sentinel cost is negligible + // (4 bytes vs the multi-kB K/V tiles) so we always reserve it to keep the + // smem layout uniform across modes — simpler than per-mode policy plumbing. + static constexpr ck_tile::index_t kPerBlockVoteSlotBytes = 4; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize() + kPerBlockVoteSlotBytes; + } + + // R30: byte offset of the per-block vote flag from `smem_ptr`. Lives just + // past the policy's K+V smem footprint. + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPerBlockVoteSlotOffset() { return Policy::template GetSmemSize(); } @@ -513,6 +550,69 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge }; const bool warp_skip = compute_warp_skip(); + // ================================================================ + // R30: per-block PV-skip — block-wide AND vote over warp_skip. + // Hand-rolled (no `block_and` primitive in CK-tile, no + // `__syncthreads_and` analog — see R30 idiom catalog §7.5). + // + // Protocol: + // 1. Lane 0 of each wave atomicAnd's its warp_skip int into a + // shared LDS sentinel (initialised to 1 by lane 0 of wave 0 + // before the vote). + // 2. block_sync_lds() — all stores visible, all waves rendezvous + // (uses the same s_waitcnt+s_barrier discipline as the K/V + // LDS chain; lgkmcnt accounting stays consistent — idiom + // §3.1 / §4.2). + // 3. All lanes read the sentinel back into a register. The + // result is wave-uniform (and effectively SGPR after + // readfirstlane) — used to gate gemm_1 at :607 / :665 below. + // + // Cost: 1 LDS init + 1 atomicAnd + 1 block_sync_lds + 1 LDS load. + // The vote slot lives at `smem_ptr + GetPerBlockVoteSlotOffset()`, + // 4 bytes past the policy K+V budget (see GetSmemSize override). + // No interaction with LdsSeq rotation slots. + // + // V load / V->LDS store / cp_async pipeline stay UNCONDITIONAL in + // both per-wave and per-block modes — matches upstream SpargeAttn + // (R29 audit) and CK-tile LDS-rotation discipline. + // ================================================================ + bool block_skip = false; + if constexpr(kPerBlockPVSkip) + { + // Carve a 4-byte uint32 slot at the LDS tail. The cast is safe: + // GetSmemSize() bumped the smem_ptr allocation by 4 bytes (see + // pipeline override above), so the slot is dedicated to this + // pipeline instance and never reused by K/V tiles. + auto* vote_slot = reinterpret_cast(static_cast(smem_ptr) + + GetPerBlockVoteSlotOffset()); + + const int lane_id = threadIdx.x % warpSize; + const int warp_id = threadIdx.x / warpSize; + + // Initialise the sentinel to 1 (skip-everything) before any + // wave votes. Only one thread does the init; the subsequent + // block_sync_lds() makes it visible to all waves. + if(warp_id == 0 && lane_id == 0) + { + *vote_slot = 1u; + } + block_sync_lds(); + + // Each wave contributes its warp_skip (already wave-uniform + // after the butterfly in compute_warp_skip). Lane 0 of each + // wave issues the atomicAnd; other lanes are idle. The atomic + // is on LDS (s_or_b32 / ds_and_b32), much cheaper than global. + if(lane_id == 0) + { + atomicAnd(vote_slot, warp_skip ? 1u : 0u); + } + block_sync_lds(); + + // Broadcast the consensus back to every lane. + const uint32_t consensus = *vote_slot; + block_skip = (consensus != 0u); + } + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { if constexpr(FmhaMask::IsMasking) { @@ -530,6 +630,10 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge // R25 redesign D: when kEnablePVSkip + warp_skip, we zero this // warp's owned rows of p_compute so the unconditional gemm_1 // contributes zero to o_acc, and skip the rowsum. + // R30: per-block mode uses block_skip (uniform across waves) and + // additionally skips gemm_1 itself (see guard at the gemm_1 site + // below). The p_compute zeroing remains so rowsum_p -> 0 and + // `l += rowsum_p` is a no-op for skipped iters. constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -538,7 +642,15 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - if constexpr(kEnablePVSkip) + if constexpr(kPerBlockPVSkip) + { + if(block_skip) + { + p_compute(i_j_idx) = SMPLComputeDataType{0}; + return; + } + } + else if constexpr(kEnablePVSkip) { if(warp_skip) { @@ -603,15 +715,39 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge number<-1>{}, bool_constant{}); // load next v_buf } + // block_sync_lds() stays UNCONDITIONAL — it is the + // workgroup barrier the V->LDS rotation chain requires + // (idiom catalog §3.1 / §4.1). Only the gemm_1 MFMA is + // gated on block_skip when in per-block mode. block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } if constexpr(std::is_same_v) @@ -659,16 +795,37 @@ struct BlockFmhaPipelineQRKSVSAsyncSparge k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } - // tail — gemm_1 runs unconditionally under redesign D. + // tail — gemm_1 runs unconditionally under redesign D (per-wave). + // R30: per-block mode gates the MFMA on block_skip; block_sync_lds + // still runs unconditionally (workgroup barrier for LDS rotation). { block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } } } while(i_total_loops < num_total_loop);