From d8046e1bb41c64b330eb29ad99ff33d62861d8ec Mon Sep 17 00:00:00 2001 From: ErvinXie Date: Wed, 24 Dec 2025 15:39:44 +0800 Subject: [PATCH] Kt minimax (#1742) [feat]: fp8 kernel and kt-cli support --- README.md | 2 + doc/assets/MiniMax-M2_comparison.png | Bin 0 -> 75156 bytes doc/en/MiniMax-M2.1-Tutorial.md | 198 +++ kt-kernel/README.md | 59 +- kt-kernel/bench/bench_fp8_moe.py | 286 ++++ kt-kernel/bench/bench_fp8_write_buffer.py | 294 ++++ kt-kernel/bench/bench_k2_write_buffer.py | 158 +- kt-kernel/examples/test_fp8_moe.py | 457 ++++++ kt-kernel/examples/test_fp8_write_buffer.py | 389 +++++ kt-kernel/examples/test_k2_write_buffer.py | 194 ++- kt-kernel/ext_bindings.cpp | 60 +- kt-kernel/operators/amx/awq-moe.hpp | 776 ++-------- kt-kernel/operators/amx/fp8-moe.hpp | 782 ++++++++++ kt-kernel/operators/amx/k2-moe.hpp | 1302 ++++------------- kt-kernel/operators/amx/la/amx.hpp | 3 + kt-kernel/operators/amx/la/amx_kernels.hpp | 10 + .../operators/amx/la/amx_raw_buffers.hpp | 488 ++++++ .../operators/amx/la/amx_raw_kernels.hpp | 464 ++++++ kt-kernel/operators/amx/moe.hpp | 663 ++------- kt-kernel/operators/amx/moe_base.hpp | 763 ++++++++++ kt-kernel/pyproject.toml | 26 +- kt-kernel/python/__init__.py | 17 +- kt-kernel/python/_cpu_detect.py | 91 +- kt-kernel/python/cli/__init__.py | 8 + kt-kernel/python/cli/commands/__init__.py | 3 + kt-kernel/python/cli/commands/bench.py | 274 ++++ kt-kernel/python/cli/commands/chat.py | 437 ++++++ kt-kernel/python/cli/commands/config.py | 167 +++ kt-kernel/python/cli/commands/doctor.py | 394 +++++ kt-kernel/python/cli/commands/model.py | 409 ++++++ kt-kernel/python/cli/commands/quant.py | 239 +++ kt-kernel/python/cli/commands/run.py | 831 +++++++++++ kt-kernel/python/cli/commands/sft.py | 52 + kt-kernel/python/cli/commands/version.py | 118 ++ kt-kernel/python/cli/completions/__init__.py | 1 + kt-kernel/python/cli/completions/_kt | 153 ++ .../python/cli/completions/kt-completion.bash | 73 + kt-kernel/python/cli/completions/kt.fish | 74 
+ kt-kernel/python/cli/config/__init__.py | 7 + kt-kernel/python/cli/config/settings.py | 311 ++++ kt-kernel/python/cli/i18n.py | 655 +++++++++ kt-kernel/python/cli/main.py | 436 ++++++ .../python/cli/requirements/inference.txt | 6 + kt-kernel/python/cli/requirements/sft.txt | 7 + kt-kernel/python/cli/utils/__init__.py | 3 + kt-kernel/python/cli/utils/console.py | 249 ++++ kt-kernel/python/cli/utils/environment.py | 1108 ++++++++++++++ kt-kernel/python/cli/utils/model_registry.py | 374 +++++ kt-kernel/python/cli/utils/sglang_checker.py | 407 ++++++ kt-kernel/python/experts.py | 8 +- kt-kernel/python/utils/__init__.py | 4 +- kt-kernel/python/utils/amx.py | 77 +- kt-kernel/python/utils/llamafile.py | 2 +- kt-kernel/python/utils/loader.py | 111 ++ kt-kernel/setup.py | 23 +- kt-kernel/test/per_commit/test_basic_cpu.py | 3 +- .../per_commit/test_moe_amx_accuracy_int4.py | 22 +- .../test_moe_amx_accuracy_int4_1.py | 22 +- .../test_moe_amx_accuracy_int4_1k.py | 22 +- .../per_commit/test_moe_amx_accuracy_int8.py | 22 +- .../per_commit/test_moe_amx_bench_int4.py | 3 + .../per_commit/test_moe_amx_bench_int4_1.py | 1 + .../per_commit/test_moe_amx_bench_int4_1k.py | 9 +- .../per_commit/test_moe_amx_bench_int8.py | 4 +- version.py | 2 +- 65 files changed, 12111 insertions(+), 2502 deletions(-) create mode 100644 doc/assets/MiniMax-M2_comparison.png create mode 100644 doc/en/MiniMax-M2.1-Tutorial.md create mode 100644 kt-kernel/bench/bench_fp8_moe.py create mode 100644 kt-kernel/bench/bench_fp8_write_buffer.py create mode 100644 kt-kernel/examples/test_fp8_moe.py create mode 100644 kt-kernel/examples/test_fp8_write_buffer.py create mode 100644 kt-kernel/operators/amx/fp8-moe.hpp create mode 100644 kt-kernel/operators/amx/la/amx_raw_buffers.hpp create mode 100644 kt-kernel/operators/amx/la/amx_raw_kernels.hpp create mode 100644 kt-kernel/operators/amx/moe_base.hpp create mode 100644 kt-kernel/python/cli/__init__.py create mode 100644 kt-kernel/python/cli/commands/__init__.py 
create mode 100644 kt-kernel/python/cli/commands/bench.py create mode 100644 kt-kernel/python/cli/commands/chat.py create mode 100644 kt-kernel/python/cli/commands/config.py create mode 100644 kt-kernel/python/cli/commands/doctor.py create mode 100644 kt-kernel/python/cli/commands/model.py create mode 100644 kt-kernel/python/cli/commands/quant.py create mode 100644 kt-kernel/python/cli/commands/run.py create mode 100644 kt-kernel/python/cli/commands/sft.py create mode 100644 kt-kernel/python/cli/commands/version.py create mode 100644 kt-kernel/python/cli/completions/__init__.py create mode 100644 kt-kernel/python/cli/completions/_kt create mode 100644 kt-kernel/python/cli/completions/kt-completion.bash create mode 100644 kt-kernel/python/cli/completions/kt.fish create mode 100644 kt-kernel/python/cli/config/__init__.py create mode 100644 kt-kernel/python/cli/config/settings.py create mode 100644 kt-kernel/python/cli/i18n.py create mode 100644 kt-kernel/python/cli/main.py create mode 100644 kt-kernel/python/cli/requirements/inference.txt create mode 100644 kt-kernel/python/cli/requirements/sft.txt create mode 100644 kt-kernel/python/cli/utils/__init__.py create mode 100644 kt-kernel/python/cli/utils/console.py create mode 100644 kt-kernel/python/cli/utils/environment.py create mode 100644 kt-kernel/python/cli/utils/model_registry.py create mode 100644 kt-kernel/python/cli/utils/sglang_checker.py diff --git a/README.md b/README.md index cf13969..9c5047d 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](./kt-kernel/) and [kt-sft](./kt-sft/). ## 🔥 Updates + +* **Dec 24, 2025**: Support Native MiniMax-M2.1 inference. ([Tutorial](./doc/en/MiniMax-M2.1-Tutorial.md)) * **Dec 22, 2025**: Support RL-DPO fine-tuning with LLaMA-Factory. 
([Tutorial](./doc/en/SFT/DPO_tutorial.md)) * **Dec 5, 2025**: Support Native Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking-Native.md)) * **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md)) diff --git a/doc/assets/MiniMax-M2_comparison.png b/doc/assets/MiniMax-M2_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..408457af22986eb4dc31d596110e8ce416a4a7fb GIT binary patch literal 75156 zcmeFaXH=Eh)+LN(m6n!DnGjIGZ~y@@5ETg~ghLR?SyYl_L2|UpKq*NN83a_4WC;fp z1XKhhiG%|J0xF_p1jz_|b7S4N-yi+6N009DwWF%MDsrA@@3q&OYtFg$y><4C-1@cb zYnhms)>GwAsWCDAzMF|@<-Olm;dkz)|ILM;BphXR9M$d299=Frm@+9}aJ+17=V)zl zai_DX!xalVTjBkO4(>m^Z|5aP$IDkF4ji!gudmo|=U{%|g}&1pe3vzs<#n$xF>TvK z{#{`n*Qmy{f{BTG>bQpMy`fGQ*DZ|s>9P8x;0+I2&z?B&%ccYuSqc8zj(_gk61rRG z?CpyyHGbvuK7K;gJvh}dgs0JCw2}HtLb%uBhpI&#XYn@5_Dvra3NsB%RH8k2gtzBh zsXwwRzTwgqeDZ&OTwQtC)cHTZAU}VoI==gt|MnY~SvPI`Z@;`+cITh}_DiN6H&6Zg z-+uk!AO%bHKYsY`+K3nWuOElk-TM8%e!;}_{DSOrywsoug zvTQpGb#1smS_M=9cdreRU4 zR)!wG&g1)@o}QmSecJ8g6S& zk9pb6^%C}MaUs_q!HE15m*_Wg?{zGHKikxbpaBkEX&rrw%x zr_nAcZVF|uq@;m)eJo?NbSd-2wbf~7ADm!6c>e0{J$st+l<|0yY(InGYVzBwD^`lw zzCWDiR2MDY`uA^t$e(ak`Y~3?CE?h2*35T>q$-bcRLI5XVvlLduWw?4c1BhDHC)OY zoKcRGj)~N7?8oB_cN7d(1Z+ zucd`)t-*7iW3qyJdGfZlwn2ya6Byk-OE&app0nd~Gd)3?_m8bOa-`E|!>*k>Kd9ck z$*YrVRyWaV{xlm_jl5H=+j4X>#MuL)B;6dMDl3<&sw>^+r-t3X#<++e0F>IK@n}NZejbr;Ti7IB zA6Eo&SkXsrY&-Xs#;y2yWW<(#upwS9$Leg>xneyd@7Y9Ix2|Vy^z$!%iC&PDlvLX& zN;4j$GVbn?OSS394?4^)Dd_O^O~(7DSInJR*(^W3x+S)G4F4=5$@*(EWyb_)-a%k$Yx!QHjdYV!Fe^ z9#huq*@RAEqg~|J;lW)!|IG|4x0q#qoQ;IqM|Z8>lcOZ-szPz68764EPeB%r3Dhf^78WkAV7q7rB+rZP=l?z zcElaF|FkFYOI27yWjlkRDdjPxv9!3bb;}kL)$VvX=JgyB8u6zCa@u{WXxz);E_vhL z&#jhR^gkURxZ}vf6!qkRPM2=R_j;ME9Mc+pxxgLc-@e^ROH0c!uh0Bc?w4j-6E^qb zd%e@#ghf$>80$I~qrd+8D`}2fG3*iki2v;UtN43K_wVmSXdA9(=FDRyvBoZH8Ir1z zLeAgp%*-Q$rw;6gFW+Cpn`|+D>N%V9r9LhzNP14598BSQ$#_Y(@z+gF>Tlk>8K0cg 
zd~o7M^Zny1*Y7w?(nsz*Ym(fF6MuUzEw~iVj})!lutDpuUw-@C-_Iv5F5delMqkvj z`D~oDSKj51CuC2Z(wSghvnJd9o-Z@k{{BUYQiR&srX*#}$LEi`xw`K4S^Qzyk{(Hp zGFw1}NzZm^&nJwrlP6A4sCVz)jXHSY?*T?yG9taItn4JRpjMIdaGpc2g71#_?c%nt z>+0%uukx@B>A3CF@KBb;Cd;TalgsCaD&9mZMJ@63=;#4$ZEf;F+L;FKUn0#Xss*z8 zu)eMBbgPU-Nug@qIM0X35EYh&*quBa067W8)G8_U>7$xIhfp-s94j{p)V~h&99%IJ~us@r5L`yjUF%&E5IBZj(2@r7_e$R zyUvhr2CvrU&6_8&=4p1_C8=tOx5s;HcwJmvup@g99z0twv-C`Q>Bp6r!}c$CRbF10 zn;d9N^iv4gyX(M#1E>zgNUKS6q(X4);al9yd#0~=x=WqBV}F0jGq>-l7s_r%9d*gv zv}x16DH|+C()2-m_pN)+9vg0_H{o@M+v8R*%zlk?M@rGm&=c*gk2B)rFVq44Nd)C z#$Yg5IXGfQymBXn_o*fP({M?I-FE|rrSQ$0>$kJ-L1we=$baV^M3t96xPkTH^S;lY zO;lS8ovhTCYF)c;1q6sr4Yt->T2qbAojX?{vpBP-ZH&V{^Ic)4+Fbh+`*>MN0j>0d ztA6{fjnQ3IRW&n^mc@?48JlFxx^=7R(7Wt-#@uw5M}2cC-b<_}4hOxVtgOt;ACV{8 z_4G<5o+m)uu1iGBqLI?(HPBg9a0oxz77nooRA38PYr+rMwc2Cj56*L#C_9vf*WWv0 z>^aqDp3!nZH|Lc|#eW*8ggnkgH{O1jln z-`;k>*4{hau&CwXsN#WQ6Q^4D$pO>Zp78diJ?-vG79IH+vbxhFoq}^d<5j%QTr&EV z#^ax>Qlv*?VqyXo=BLF^%gX9g>&UwV))k81T2*xllM|Z@XJLkW< zCZE+7&43cInj@RlmUI{ciR_0F-r@XlFT9sEF1EgvpgYIRh+^DRmNW~jeNejl| z#B~*WSheMtw&$4oMTyy1Ug2<}7P^eaw&&YeeWy0$4aGYyrHmjv2LP=s2J94ZX7J~k zfgDE%P+Xm-73;k3Nqgz7XBT;kZxw7aJ>Bh-cKq6^wd>c@J*=tiPp>4$`YgJ|9J#{N z=8P<+`Ta|s%^4ZH16C@|>;b3iVve?&#Q{hR0G_e*avXIIiH?r$pF|iku(29(HTL%Q z7SaVnqSo&YdX1HD(;6h18^u>ZKi`^-{iNkH?DyPA(P&6SM1*w_qR65tDWD>VE6&;) z5Yv6OH{5CQhg;gD?yv!yn016+zO}{l+{atDZbixXl<;e3rmtMRA;D6A&6+iq-6h4| z3zMyR*ZHY{Dmimg!)bkem!fT2(skqa8y4>6=U2etJ!9W;jid};E^qn#I+t#*WT3OY zw+~J_m%9zN2u_Z4>KD6DDooZ7H@HPz$Vi~BBJZoDUrb9vG@N^fr!AaK{r^vc&aqVr3~4Giuj zor$q*Ns_<2+Y%5cH9s_Ly6pP(UkRjh?@tWhXLjap!o4F`qRdiV&nhYTyHE68aP~i7 zRAMdWr$`%W%eDGa8|mOs{VvFWt2g^hfN!{q!l$oa!!LB8pE=edyWEKgF|e7FZRT-_ zsr8`(+IglmEJf{8!yVd4`sxM-2DfcSyNc6*wvzZ+Jg)otrl}{N*}(A~C__}x=8mda zeQeApkqRn7A6Ou9%^sCl$+pb?8g2`0KY<$RqKjow;&z#cA^nj0Na0E|E`6uYA{TKl zI}vtuuMIN%$TW}L&a|ee%L!;_h9oC{-4n>~wwXui^*5V@NiTC`UT1nSBK5RC>jl+V z#+y)TFptIO*LN)nAGpkpo09fMJL??1s2;_I6krXz!Q=Sw+_X;QWzFO>96-$VnNwxH 
z%-SX<35-JBTr-Og&;Qc0Ge|NSXiiD-9@Q&$ryJJ%kc)Nv##jL@$I)sm3y1hn^SrfDhHOV6kOvT|kUQ5V}Lc1A-(L!tX* z>f-EIuDXarCy``?2alRoZ$W-jDk&*RX*cU{OcZtfb{?ll#bHUm zZMn&t04SPNWGqk*S%$yXMJHDU^XePQwAR!pA{MmJ@MxrJ@b%V3Ya>ZJxVUIrTc@L@ zC!X}(V1d|TQQ;VA%az+MW_=dlL7g7dV*T#%i!ZDD_V3?Mj@UWZkG~O+fX75*bsdK* zIEnt2I|b)V8jUkQLTM%luG;lt+G-$7r#J&%5^}ULJI-)-!m8cHpGbL7W+b zty%hc^XM|0i(KsRHV2gFKYsf3@Jx(EKYCMPf$k!gEL6+~=pmv)wb_f4!+bkc(SdVTQ- z)RYV?cU4V|K$)?88d7r_-kIOv$z$8DXDL2Q3)-rxs%ADqJzkaF-A30|tyjBz`ElmC z;?%5?x#|IbUNoMq0A$uv=RUo@l8Cm$^8J%+NPOGTA49oy!c~Q^!#%s%~TC4a++R==TlP3$6igA`)wl|iRQCu;Bb{2lHFvZ;sHLp zu4e!%w*ggBuU=&;nky_SvU+>}c&c_*?3*~5EYvW)_twD)IB$xywJiG{sYuVB0w!wu zQWrf5%%DG16LZuh&vVwHc&gpXtQXO*0qC7xp<3+j1l*#DuT?*F>N@%UBsM4_{;Y>i zkkk+RS1d(fjxBn{&Q~+il`|B5iEb(tD9tmm1 zfb>DRxw&Y$Ch@ubS*44OcBKnuhte7&4jDHW*cth?F(=F~b-#Oj;o1gH$xB*`{mF5p zdIJ8Hrk8wmMn8o2R%Cc5RK`oROrc@F!&uTkdi3ZZL))nCjtv;4oqiq@X~%I~cM2ys zFBs_$7U-i#w{zJ+vXR6HRv|sVphb&z=-|P(tnBQTE%{#aQ(_&ZKgMg!Y~thN17^Q} zEogVf`}MciBLsChT)usH;XJ8A#V2P3#tECB=>aKXJ$2N1D8hbjqMxNVza=Hi#PfTd zxL|vCiC2ft!&7V0hCNUS8l<1nhwz;h{iXo3MOXd$tB#F~!{^t{9O=G04qtvg@C5m= z;8Ky(pjMhqWx)3Ei$_&dRHCiU@*eU2@x=bsKN~CFy?f_49NZA`qyF2+EId5>==-O^ zBQ{BS!@#nXj`|i4e@c(Gkyq#^XJH(K*4EaGG2hTsv0@`^TRo_!V+4>3>4gi3h4vj+ z5SC)y`dNlWW)0r$c@|n7Si9{6bWI-;VO>{Y-3IUL)_H8lzYy*<&-R@^gJNth!qyzOy+QCf)#=%Q`0o+E8Q>R1Aw# zoagwhLtR1k>iErN?_o>q@HRH3gvSi(EZQ}Mv!Jr@I@i4t%Y!l=$ zJ;Iizk#f=LeL=T(#|k+@Oa|=|ke4y>sVJr-5cF5|Sp~q8FX=wHN2- z+-EGBC>~QoREC($?B~sBW}-tG8Tzee^Q0XeqJx2m60x{Q&&UuB0n*&MeR~?nvp&-Q z#7H=4fb>D<_!`(cF9;+~cDoFH{!-oKU@M)3gcDk9MQ5?U)e;p_^b2gA?P}1^6T)L} zeXLYbT}s8$kMAl)mCWnbH6wdwgfX*=?CtE)SATHgH?)R~hn{#)7Son6)5}gyn&=B^ zt2w2a6Xyq=lSO)tlTHhWql&>0z43@+ur^Xy0`Iq{)6B9tOa<^1uTFCwZfEI5 z&+$0o#;zlHQ?1$Y-e4Q!VNb#d=6HdDY<_y>tBH@1v}J#BrqBuWs9V6zXn|A|S*|u! 
z53%u07x{eU_yjm*teTcoqC~0&tj}1)DDUo%Fna1Bi>A9dEvqCINRJ=7y#i&ez&!?VOW$SW{UL!X_Vx`(Q~Y*zqawqK?vl5SV`z^a?WJ2ZRSb;n z)6TrdSW-l^eCU7PNXV(n{~R)|xCB^K=rL^@HJ9uhV`U{y8~<8wN4LSdYCdzdFY561 z_0<ti4=H27t zh}PO>opd*3w_)4)N8K*xd={M2dcmRP=`DWqplTh_Mm4sWyh#rZ4lY+-Tury>&_X`Q z+q(n~z-GVEx5out?lp_YG9FKLjpEdRzDgUUqsF~|Slczdh`82`>5>L{8yg$zeA?4U zM<-&#j>G5k)BOZN$T)LbHRBn<7YPY$2LZ;iPv<_0Hx8>B;V8(Og4cQa^l2jh7dh_Y zR?X7~S=YJG4;R#h9e4^z)yJTqad~4aeOV?G4^q(Oywd^_Eds{dnmtV zwOLGfc=*~ad+(z_>ovDU|8wX4Oun}nz)&Ih^txZycsyI28zkgld4S)MD_*Am-d(ZA{+kP5>ny=># zVW22~a9|*6C|z*s@np%z{(cM5jW60gZWWI%K+cM4kgOf)wS>G*Vraq^ur(%7Y z%|J_rm8kvraZrbi{%pc({>do&YuBuyW_Oke&FORiB5y7J_VN!2*KhY_4G9nKz2N@a zs#Ox6Mc7?e&&MKQ0GJw0K{U=GEjCG=u}n&z`u;sv&(^AaeBymxSi(R5{9cS_=~p;$ zN39?zU#F;nr%jrt|8C>vQh<7`nqe;qsi$Ko5Q4E^6o+T>kwMA^hBl%~3fGBUut~W| zxIrLxt;(V9(o(Qp7Ol;*08|ZVFc}bVkZ7xic*5gJmw@U~py%@ai!yrnw*#BIBl)ne zeEC)W(;QY!A9(HUvt=q}1nvnis|dutY7vc5D=uEcam3-X{IvRP(D*`kgGwMlF`3R& z1}hsEcpogy^e}G~)DtPsM+*VW5R5g>KuPM4^I6Qc+Rj#|ic|NR6eLqan#rrW&gR$E zm7~xS9!FAkb(V?|oRpM}{xV{O@@rK|_ll|o&d^1>n~!Qk?pC&kp_NTR$E_9o=j|IfsG>_7GwJkfbug_6>pN`yi>sj^aFxwAw-50Ht3=wke>N) zO`|-P9g(Y-eenh$;m70ikVQZ)Ze*&(G`+ z85mZ+p7Yf_nkF~9824VeEAR0%b)tDkG8q0o@G^uck#HV*1b~;|D!qtPV#!DiE#KS} z`Qk2?x(_T`T_JklG(26WRZ-+r35->4)VP_OOT1=B&Z=!`K=m;9+T$Nt8-?xn@> zfiY{rBf;k~Gczq)Gh;wr^Y$;!e}|akI5nipA?bRybYUvGynY0KqL*WGi-=Ua4oUF= zk|&wRd9@69oH47v`Ry=xJ74MfA>p)JH>$YZ5?-~8UPW9!0)52NYq7v_8C(gfi9nke z()9ocROg|07h09X*|_5aC(Fd_y3(-|x-q)RVCzX-K{6V7JPd97!&t>nhB2H}Ou;NY zdMaZaWy8eeFxzf{VPn#qNcXzpS#u}<;EZg%^7)#sp zL1Y24?pZ_p1ZMvh?sx7E&OH?YET)K6s*v=Sy}%`>urGmC6Cx^JOA7^M zssK`$h0{HJ9*>7{KbX~&xjRR9?|wO;@$JHe3t}>t!2x|8<|Kb>4u?sw`hdZc`%qhx zJ${7f=*NCUfRPL!>!y&qYXa;hk?RT#o?nBcpB|Q+x9;YinYpYT`@L?Loa(2IshR?5y`W1NCdMe%hzTRmkmRqX6FKSbII$ZP$PHwO{MSTt;A|2!)I-jb6^{)Efi?im2Pc9$tEOb)OGb(#F(k#Gc3%D^%yxa{4TpJ{l~{E&5%98 zf2ugAlh;PmqK0qY{P@C2&2$}AZiSFLrKP0?hK5e_Q#K&PjnFaD7>mEH-6G%p=3N!c z8-z7PaY%q1t~QI*N$3^R=T}z{owZy!5w$7>V4j#C_%%}r z^dNZm3;Q64gyUIgggZ8=Yl&R`aBKldn?-!SHA6oc-E|t$3XxVrusG$Fm4jfPCEO<- 
zO%1ihL(GY^o$63vQV1dz1W=_w;NgIR4@kOzW{=MGE+Fad-Mh(xk-*8$&SqKQV$8qk z*$>^IJ=@s#Wm%cBDfqMcSO0JxK(p8iELvx#NSfT5JMJpLL-zj`s7O8_R~G1V!}4B3HgB`8nvT)XmS#n9KyY z%UPk`J}WL3v^~Ym!($1IPxJ|0x6`A>000xflsd(#Y}>XaqcZ6&L?(p2IDe<%QZr@DHb}#jh=rwTW zVnGF(yR1XA<(VM9fFE?{VLgJq;oHc4B!kM#$F?i2>yhWCE*GJB1CpHyA&cHbx0w`r z@G?sF{(wk>(RL;F@{+png*+9x@bcHRH3knJJRq=?%j+yN6h8(&i4NvjW6}!p@+#Sh z&uiAMtseT5(+YH_eV;3WrU^_j7S$M7;({FoVI+i|xCGUp{6TxYtIg99N4!(MEpf}hp~xE6NYj^D_K zntnmM1NmOhuu1v|Sy;U5kflD@qcj!Xq*kk};;>H|FZl#(?r?$wXBW3MxEWxl2yrBj zNVmCO%K{oQCfIdOdy^?Q_AVFBu#I08aF{2NJ(5%Guv2*#%h!T7?|yK|R2Bpp0|X^h zh=2l0iug4M}W|w*O{p7&%i)=vfkY}2Hv;2u1-+1r=tux1>`SX*XYsNT|}oWw_m$$|GfblpmZG&wHoA^J6kuRmgeEx z<1fCkT!08iQWh3tc%z(IT;-$}Z(581X_~MXL+8%DIDfDYz^nyQR$S`#HiRUKc(ht^ zJvtm99pju)VGE?*MGaoLCr_XDY5NPryoVW%mc8%RU57mcHhrZ=f}gLvKP0QSU|41Fs^A5kAb?DE*z*&)%$ z!>RQahbb~C+P}Xc+gOfn*L@C4(~R;=H?G`3P908`g1r2lxVSi|L7IsQTSaZ(>wqy; zWX!thXxsFQpqrWvw;2KQmfM<;HqZO=6i!i#0%T=rw7 zAm)Hy0`Q}djOllv5+L@~@$saU0 z4*^UY`avS7j7!)QQu6X(kRaVU9xN1kF&J#3@uWidAqo?=C-g$Li7C{g7N8ztTC+fN z-Cp8V1T`n^?c1{`+^=D<0g0f7yrBW&i6C^utAc$?Xx`jVZv6x#WlOZXSl#fc(e5mw z2;qzoZ5+0$XV0F6@Tzc+qzhA$8UB}S-JlZ>U_$CW0xroOpQIIG@PMuEp67nNRe3>-< z`3wB_f34H>(EoUz|9nF1y4wG#PxkeBLS*5*5G>L5;NLH23;gM#eSIEiI+UWBqQ)J( z$Nt$^Po)(ssO`lbE@r+oejot`TAC!B-c*w+8r&JgyoOw$QDAFGP{Zwj0D9XJ5JJ=; zXQ_w(bRvf=p`lB6397seQ zKYrYOnBJP_*sm77fN)BLm4Sel@$cV*@tNaa>e9&R#HRCcysRX`uPDqmIGTHNPiFUH zudx?T2q}kWHg!Inl5?SKRneQl6)RWi4ErF!kum!axY}?CeE#~?eBcV4iCYSMpcQyQ zyd#i*?N)vY99GcH5Sa?>uq+^__VgK&Y{H~-;UJ7EwHt^Yz!YmjTo*XK6T=1FX#2Dx zIReYz1FC_jhDMv5MGZt|?S?8_B6lK+?c9e42h(6a1lo+QCV-GIs_3wV2aD{wrGHz! 
zT44YFHyU5Sr2%8(kjT>@_D=wlreS}~yf$yzqyfe_cUWgLnq!6QI%wwrGRzuy_U%(5 z`482Dn+iH>adDBH6ux}4AH;9K;)P~aA7;ho<~Km4;3&u+r@+7J1yFF^q{njy14=(6 z)B@sG!xxI=mStPwTN1zwgQw%yHylq33(cJ6owozuu43lghopk+)oE6QXh}m6VCf}A z9eHkFHF)_goSgfihg_*eqO*b?kkoqS%o*Z4gq?~ey~wYZM~uG-IB&^j1!OguQeG|h zvrKnhh}ph>gghjrO8uRdXI{@)US3XYXjtL8q02*NBJ42zY665RgLKu<&i5E?-dtpZoa4@QuH2gI0ZODh^*7<-!l zLv6g#IM+D#AR3JFqptr6p$ECo0#%ln`&bT~OD}R69YClPA2)bS;v1&JlJ;;^MP?xZ z48`94`_F(qZ9vr{4@JHjQi^+TsJ7BD{`$SHvY5nIm6g$EqmRxNyYt8ycVLgWJL6<~ z?Lh{P)W6+DMcPrAk}L?jei`^56+acd6}zyRKO#RFy^#V9wwZg@oBpPd-TR0)WPXh6tf!1fC{>OKL+I0+lN?Z^sWJrZ8Kf&34qZ>|c3y3n!R{B&352;mD+ z{d6+~G*d%}X%uI%507t_%R>PZ7lk520ZE_hz{|@^(lqEJKb(t4h-2{DMlfdDr+y+_ zE$FpfD3%M~%9vRuiJ=EAC9$AD#32g^lql>M4-wcM9hBf44-d}(B0HP{fuLn~!~IJQ z52J+51Bs}OvPwE|Vy98t>Z9e1;|XrJ3KBB*0;f!w+NyQ=@?{=5nmhFtC&BQk^Iu-y zJqdYHL;!7EFl9b0e-^eu%~EeKzw5t&67qsNi98em@-PL;=KwMXe6`0?jhCkg-ubdE zo9uV4#F3WA0fIu3_s9d59@B%YN-7_S>lye+6GnjE-rht=MFU^~)h=Fq$jI(SJAtG| zthwU0@86;WeT0Nng(QH3-9(JH1hiOoLxYWIF_b|`s+;lKD!_wPN7;#}z6pO`e#Lx= z*ZjT~pHej6#lCxM?46uKF&RMNgOiqJ5=cB7p7giZ#Dj>4qG96!`wbzTC>DUm#ELH) z&8MCe5&I(+TB`1*6AIw>AAq!`(qPrFqX+o;`Qe3msY$)NgkEC99w30zbh0lJv5ORM zWOhrU?9PV>XrICB!}%9qCnn`SanNuZ6=BDVqg)tiBf=gEESY;n%;MG`RQ_|%p7cMH z0SAs;LTt>$5k^ip5MZ9$ceAt(LL;H#pio#>xPE;3Ta$BroQxem#O8ebHnAF08)4Z# z&teSy)jBY8>-8Hq2tPvmd`Fiw#1GT7v)ZN30h`cbL_fZ;`VIilAIMAs`22R^-Q?$E ztJ>i)NtkcC-I0pVPbXp&JZ;*<>LL5?O}X8-DkHYeZmzkMay%?a6vAr+Ik3crizN94z;*?VBl<7a6o}j_u3z2XV1}++eDsJ>h z9xOJ*^#D7A0u7kc)cO7JM(|(~GJ#w&w)DEig?ZphX{s8r@Q9X3A;mFr%~b>qo}6u5 zVB?4G33@aY+jP?B$CxiXSPyUtYSWGI!hByo#|Emn?_PRwWeST~9n2d<5+nzQjly2n zvyGzcoSe22qs%)F-{+Ka-;2o(xP=vzl$z0olH6$1mLm%%k2VOyx(aGxe!Agz=1SOr zoVW6+z0yE-6zzrd3%M?Yg3%JiGiR2j*O81x_Kz}*!HzVL44)s5QRdA~$y4>R$fyO= zv4KN-mTRWT{e%-tcdfhJijElrm$Ej!YX$H37T-ne;l0D+Vg9oX&pJpedQcL_;` zL)ue^HUXCoAetVIWORQ*R^;TI9)p`0E*J8LjQpIqoQpHILD+WUT-$kr?MQoJ%UGAK35kWX6) zDFN$ZTE}I`#qbbucgh^VXb+5sd#FT{BeasJMdK;Z6~t*@VB4wobh^s#_HE*}Ar|}I 
zs$e-lUJbk|HqU8c&WSjrXbI3aEI>S(d6E5t>k!&qBsZI+jWYHYQB~ZAzyKMZU1BWPao*dI?>y;aUh5f{sze$V|fRpdjs1?}dGtbqQ7i zPq%+1trA7|eY%X>cmkk5rGK&R2Vy^}{IJ)YW3;i^Z`4sXw9~2QiVM*{^^S~WK9uzv z2oUdj3+_@~^qiWSZ+N(NRA%{VY-^O6Gl<3{ELn8s7>Nj`sy#m^w1i%Rs3b%eHSE^! zSZM+fD0~P@qPu{LA!ILgI<5uQ)O1Qe%2FXl8+4bzWb&cX(n8<{0jqWr5lDl#0(k}D zXA*T`7rq0^h{#}^(uJmtu>X&OOj+?0SG5|qnZ^zCZ&p=(kw-$^YX9;?yWkaz`Hh<-B*exx5UCf0cm^C=H?yA!U6?cFO6 z8eXfiyVR!??kv=;bRGc;_bP6eE?(|U@@l|x#6toviZ&8FQ8iGVOBUy>sEyDxlx{II zPZxY`SFlLh0`wbgOfiQdWKoid1;w-8NgC$i)&bq$|LtR$gM)(_1Tj2GGGQ-KoEpHb z4|5*qH$R8umK(WCO#tW${AMfu?S;jBZkDu6OhOwJg~`vyR#6Kdkdy)(rk4?%KuVW# zeeCBKqF=bYC8YaT#W*hA94_loRD+*Q1Li3X8E*qvGqMZ4Jm8FWJ}I8!gINwpA+}fi zLq69kvDf>f!?>;NUcCrZL8m0|%u7-@lIw*x#$!JgBO?Xjr1LrFPe zzH;SCR1b44_)3o8h$tvXyM1|){LQ>B%2bq%jAA5>5+hcW@<-I}q@Jku-aRx*!9BlX z*y_;z#)?c&kB=*;mp*@fj7+})xtt~Ve^g@4t!q5;{O{}543uXj?#w*-3#0Y|^Lgjt@kh))c0TuS;- zCO%BF4GZH_KMk0h&2|#p*uF3YGP!4wF_G7`Pr2RO-8M`iM@Q(BsvZzP|2$!wbsisE zxoX{8ASc2Z5n13%&Q|96Pksj`i0SB(NGklpde2<5C8eZ>2abX#%`hrGM-y?MxJ2{? 
z%OXg)1eDT7x-Tu_sWoB0Q1`-TPjuLorl6>j!i_X#686t|5q50D0n#8q#eR!Z%D{vJ znPEbDVPMoy)cc34s9l!!q!|S8_U|4_wcB@EDD>{!W7}b!HvtF``OFI|{OZJ2wFekrsirFYNma7#tCsb`kQ@ z$Rz{#ihp}+Lfgpv_;yYD$G4OCJd9^d&>-Otx3BidFf7ud(UkAqH4!^{l7=Q2xlmrM z;*AJ6XA6b*7EogAoI$fCQLeAf%2=H>>zmc)YwQRUG*U(F0M0$3v8XC<=%ERS@1!`c&kt9as~O_U>K5~3 z{m;nDzx`|FY6@Rlato}M#)T=fA6Th9JO!&3NaKcX`^BfEOt=)4qvN{RHs(Yj&qtT| zxE}HIH7s`5gw)guhZW{zAEHklMAwkx!SWlGUBdAZLW3Vvl=8j9huZR~HEZP?!5G-) z#z6a_d`}d>icFSoGm@uC!$%mcGB`cJP^3A6Shs<9C)4XO8Q1GKNC%3~u?c=Nyt+OE z-qGkyd#pR}m!j20<8`q?{Fc5kbljnBly_5skE3>{Nyw#T6-Iw+M9`+mKmzPM-*$I+zMujQ#~sIoRwP%5+W znZ2MpFq{)vJxtxLgWsa0pcO+khaJ6R zV&R(jI9EP^GNBH)9XJ<)wIdxYCPFY%0kJ&G(W!XcN&9zzT< zM1M-A}Yve3BILf;D`-+9e^^aUW-z|o+)j$Oq0Nf5k57Wi1y z-JjeZ+s^)eD;%}l?Jt>IK+CjyO8h`DN_frXQEH;2s0)en)b%Z38HkIYmVj{gMWHpc0E(rhckfd!Y>PtiBGSeH>lDF`I7E%j(=XV5*?bGM&6W-CJjA0 z8CpX4Irl1!SjUw*YB{&K$ef#399oGj#yPFmRG!0~P`*qj4dP|c>xDyrw%}Lkzo(23 zK;Gu2Lg6RLQRl{)CoXL!+}eERf>uuXv9eqw~hfi;RkB_{L~Nii`k(BLX!Z$L#u zcv2VZkuS`r`89Atk1T|u7Bw$->QsW&e2Y@V!Q(UnHH%^U=F2ZIcfPe6;9wx5pi2$& zUX&C}3lPn(Jbg3SiK}wgmPKi%ohvDwZ`02{_`AI68MQ`;QXujY&8CS6d+)$ND#~^= zL+Xk19OTiAsU@)5j(u-;k@g89)uA*r<-rugQt!@(MxLI}Kun)6tg;)f8cB%cgqiQW zQ%A~14i3TihKQsruxjnaF-V6^P2>62UC(q$AW>4#0kMeZaxYIXpg&4SsFLi9nyutS z=phXH5WfVavvhW(zd!O~95mzCRsD-`V=dkz2w)=HVz&vW=rr`ML1X$uz!6UFqyU@&By`Wwq9B=%%we2w+kU)LJia7hDQ zb`yQUmjDI$1i~o9?4h+;l9TkgC@~5&GCJ^S2I=H12!<`2Htj}S5)%O$)t}}JCZUZ# zN7(PIpy>xmr|f^a-I^+XiXeZn=oi!6Fv&56tP|{;2g1j zhnj>|oQ_vTE7frS__b#EWkFh{fG|%=E(m`0>J>RVM7Bb7pi8ss$0$mkguroyOiVY! 
zy`e*E229E2T?!6HPeUec8W&Cl!~Yv;`xfO@1$c&mL(%*3Rn8nI%FP(HUY6lk0SR2S`8d2_y>PMBXxSdPP{vWLL?HxD`K1A zZE|4k_MXom(?Q@7!MvO!(Lq{^DKxNiV6(S^-UCEJ=aqu|L4HU1tDD2u^a-CDe#w0} zYUp&r9GBV$pm#39hz`5z>q{a4k2b^~M%Z3r*+b?xTac2HBF1m(w2w4OatO3?un(e` zkwo)^AzXRdzkQtW8#r0|&@@)TJOMgNEh*3M^cO%rt32-FSHF6bII(xX3PP7ob;3o^JM`mY7(;PAzs zuNV;_$9kvf>E+X1287bSx9Ttk%%Ytz^Y*x{6nl4?S^>r+B9xbhhkfD44}Y3J$Os_> zME%|>GDtfyHg-rsNI}Tm)s=PIKArPULy$aK4Gj&IZOBY1_-}j%bdK!U85i|Vcpn7k zj4+lKf%L1s9Skd(T=hpCQ0P^4cPs&6ul|hYkHlSpA)YVt{6i+5v0`$j5#G+$`A|6KBJOnZQibKZjlT7Fu-F$A4LHVFs%+b)WM8& z8Baz%=>-J^hbWX+mP_b3BHQ7~;@7mttd=oyQZ}M5=@9m!p}V_#Xc7ite*_qd*Ph+G z{qaYlex0D^^O*73qS&KWM?)Qq9Z9Jhp@(X!AHY(d#`wm2RKI>b;$7=!x5p4ImM3D4 zWxv6@=C2pvd=c^+doc(*`Iib9G}*b?Krz_TmjI#yfULI2AzT7rCpkpC$BfE()i#`w z+pS^nmXDEX$BOcDT@GWoW6g};%{kcHAHvjbG0i9Qc$)(fU4OltjoFD8cEu!1^d_4>{(35U=B zv^lI`CZCLNlrdT|e+E{>5HJtu5gAmjLo8RAe|Uum|c}3n(O0 z0yQBu$7x_{&&}PPALpKkb{GesLxKq4II0EimBkAuoSZzQMOFBg&H}ZODnrJd7Q!X@`%IE8>B$d?m7FPuqA$$N#D-<}Rଘp^G@4GlYd^lIyKx~YP6|!xr<~9eEw5! z`^Ne)#YWN+EN~&L9cI3FAE0GgM(GCnE#aV-{1H@?w+RqJSfxu{wt#XfAZ1qK7xVy zSLWY~t{ap!l)Q)%5ONeMytih(;3EMM3s~&{gEqY8hgy7gBaSP0;xrG@C3Nu=Ljep@T~_wP;S z$Rgx$ZN=dKzut!rm#Y>GM+XK-2R=JH%ZIE)28qc)6w+%NG6NJGk5>)9FkP)&LE`_a z+!sPVqpKy?5fHqFuVIjTW@ct)YB$s+efxHKvajLVPeS3jI$@{(9j)EKC7q7UGf0F2 z0@Gpev4Bbbw!eSv(7_c<$5@u{irCros=Ry|%A*j8fR(J>P`Rw|z$DDl3ml|@x9A-j z+DCb}2;%j3=ktjKPF z9%>_zdZ2e98+1SHggFLqSAdvdgCt&UiIMVXd6*8^PjDn-ieyYA5n+ja%i$_y(AYhR z8#{>lMQR92zi!@9_(BQdBy<2K{UiV6L-Mm+W>nbBAW*`eCL%D8J8B5*)y>0QI2Xf| zPW^NHcFtY_gD01HPW{Yc@utJ0bNC)spvdO1R&~|jTjIg#^#Op6PfdAWSH*6mzvjy& zIZ-7XgMI2GZuX*eC<2ebO+}>*h7E3pMog|=C`m+|Kut6`$Vz}=S_hdAhIb(ybQT$c zA>+p&dH??V@8(3sZ3GMou?J|)^VnR$Q?odci+7M1mcwkVxjoz3yOL zYnSa8AMU0Bb^xoJXoX=G1}YIhhS|mDv1cdHfY1vKUS6~9E4dFPjb`0L}b;4R-fkI^V%a5Gx6pgfS!o97rkTk|!Kvoe&~1qictG0W2vQSrc9J z@d8LU;$_C$7}KSNgw!x@OUB|c%@;wwqpA(j=cw~NzSd9&PRx;gbh#+#Dq#7}gKnAf;<^1!{ zoiZ||CiAhFfua!=hLNtTPIH|3b|E)$<2kq9tufMG?dhL?H_*# zGsp|UFu?$G`uUT?QmGu3D+z3b3{l0}TnVB?6F>Ccc==4lDuK>v0m*HxS^~zqse2fO 
z05@PPm^37T3PisahVF-?8^Ulhm=P@K14)k8oh7^zT#kDFUx$X$L9!8p8a$Y95!aL0 zJsoAO2q&dCA@JKL;gy6S&|+~VR&A-jRA zga#ob=0)VI(g>f`dKmp$3o=?AjT!XX9_k#_Ju+bxO~+L);folqgFf4)!2-8CqBBbB zf&EY4+9RkJw1hRAc`yUzyHs}Tccf-<#xgwomPgDOjBy-w#lq6-FA?IFsxn{%2$ksI zP$*-Oix!(40OXCg%{z-Y`VaMBa-Q0Z@0%#S{ls$syEi$zGnVULJUN{ue zlzkv@$fX3V@sJ;h3;^$eQU@0%FiMsdJ#w8PST}KS}}%xw63=E-pfu5YV%!jH$0Q)* z)5`VB$LqWyBz1-Z2nK4)(wXjJj{-(%)B07{P>U=Z7GJP!#9=T~?asH+-?~rd?MMXI zJ&r?EUK76>GHG`u7NJv-84DvHh z-a)JF*hJVp6P|ED^4&xzxt;1L0gsRmNOO#oW2+tKYcA&c4kTCOR&`vG)SvS6QSUs% z`sfPQL&8m%Sc=(SweQF-iz?sMFj?-`?9ufj_tlAYY=aZ!;n0i2pg1B^#~><&sQt4+ z#ujU|>Zid?@+qLs#i*9I=MRz_55biXFBXvuLB0xW_Haw7UZ^Qk6iX-v&O}c> zzzQNMnd_ex<1}E#y}J0Jvu)$X2XM*jh2ZEK;<$!8E{#t3VxaCc@=Bh^QQ$6NuQ(W) z;n#kZz6fpIIVe=5q70YB*yR?}3s+Cqw4Y=FgYo!C7w97OwG9LIe{9$wc9eyWOcZt6 ztYCWmJDG~T%3A>kiuv7t&OUW8CCsVwBz*7Tet#@UEn};lN5I<j@qK0fH= z6$_VK=E}gH6OynRHuBJTuEnFpa3i(D{jE#r{ysE(RH`ewi=-w1z0bq3EwimZ_H@0X!-%=9VV^^hT zFyFJ7l5qjhf|%XSm^uPki6}Qn;{))F*Y*AR6WFGcNl(OgE8<>ofJ_A=n{-bULTzDD znT~^75>B1FegM0-5y@NdUw!W5N3@XqkOs z*U9BH(sO-#!1n46!$X(coo=m;M8Xz6iLx*s+%Kn+awHt4@r|mT!^7uhc6!k zF-Vvh!o7gI?L#j`2DLzDq4EpoMgUhNk)|Vn+-wYLMF|&6s)9|)L(K*=)_hUM2&jNu zG&0eSc@kKx$=p?KJ2v7BX4=kz>;n+bBaZAa(>n>c{j(pXuN6COqqPVJm=)%#Qh%C+ z1gPNL!veWvGQn{2@7i^o)=an3F3MlYbnNu*WuT%;iWSO{?!Jc0@K5*&Ew9q~4xH~{ zOKn^ML0o`fI}09>`@)Een5ZFUy_mAq!>XhjM*$WPgE^iqDwLN(A+v(Gg-`LeF`2Q1 zZ#ze4@dJ|j9Ec(7DN-IuWm8>*Bzp;RB0&nevp52^fe%1Nl3SdBhZ9(Pi1iD+FOd+; zdP(bMilCQWRz_p@!s>*%r6OGBP^2;W2kxstODDXSfBD%#{*Hfc0BK6c00&GpRc1vn z#$0-NeGRlw+~tSPja(%2@I5jb3Vkae1j3;Hg%wdk$}YJ|MB5W#0Ynl>rjwr*6?HJ=XbxX$QxJh2ot&t>F9|>N z;Z`tU6Y=>Hgp80No&@5@fe3#a0v6a~a`OZ*QW_Em+~W2*ToKFf^DPg2>|IdIXmYk< z-Uf%43}KM!O-g`^&@s5r$+i+10^>+AX$7bVcwz#9H?eJ9a<-VAInOal3QS5YoW1X>I-$%r(aHSE4OHiB#17^Af2? 
zXa@M;Pa;);R5?Tgj0}Qpl$BJ{6p^POSJ{DhPU7Yv?}n80$Y*Jh7_9-wh9u54E*Q&P zeDymrk_wUv`Ax?U&sPv91FHP^!~`y|cuEeD)cjDcB?b*Y z1eo0o%zUpd1=TNC112`U;Rx?( zBr~!{5fCG(zfiNZGpOK0RkZ~}^b9i6jGs0wdnbw7W(PA@Ue1;6>F#P%>^n76#lC(n z?3{Vgno?^l3EeXnDpLa@rEf*IYUtTBukEe}km97IMN7XmbpR`UYo`*?D?Eo-D*nAU zS8gWJYT||{H`HRrr@l3P$x7ckm66%XCy@yo&HMD(R~L)}qnFjnUd+r!O$P=SmPL-% z#I=!0)O8fLj&OJ~q-X*I*JsnTLoG4(Q=Qz~tlbmBKW-fV3^~C@{#ZcQBbXE#vm=(~ zuS!X&){Cm3x(X=Loe6`76$0yU5gg6a`smIFw?X_0;9jKwPu3Zgns?3scjA%I=Jrql0hwUP!}2Jz%d@Y^ck9BBT3o1dneNs z#s9Vbiq1~PS!;C9vVID^ibF{EZ%xk6w{F(xaF1RN>U{Y%*<#+o18`0~U+|6y8Li@Et{0 zBu{%G69dLDrCnzh$a2GSk;v2_jT#sv_HuuwUX8#!a31tA4ym-JQ2+~#-~Z7@$mCuN z&gj_;_*L^wG2KMb63Ck*gR=p+V96r9_(Le%@7J7D+ffwoiYYFsT(HwXJL0?%%J z`#QJvyZblB-F_QeCKH+3ku&4}G{%?-Tr^qluhm#{p< z$}LvIepWSZ**neC`WR1Yi$=u`=&}Cq^iu%PPV|7SFC|gHVzhnNdePJ`>=(DbIsMV) zVlGsAbIl6^l$2sdPXmTQMm+-7I2y25_edAYm2cs3{DK+|q4HKhrm?#}7-=)>(io7+ zaZtLgUbuX&G(nei-C1LJ=HAW@n{p5|b_^2i*DX_dwX`oGGwqu-YbJ`BK5xiqKgQfk z6?(b?7@*s$uf95L>T$bURR_=g@a^~dD6=;9h#g0&BhHwH5>Ig3=@fRhWMWVlhEcTd zqB>~A)g?44ypx~5jgH9tpgk=cn;fb> z*u`bbmbnb*eiv*_J%BLcj3Y>B?lVOQmFLB`%mjomFiX#cMK|q(%2R9V*z@(?Z5KcI zmN;&1V5+j?uEksur#vuBlHtv=gVXX9FwJSl0E9?H9-j|t_=;QfOTClY_`NyNcokmb z#fO|VKnU4pz|t=pvvdYh8`94*u;+D#g%0jIyd7esI>IYiStA zMF4ROi(@?EW^rVoj-VKlOA@B5i<|a{v+0GYPF~n^za8A!$`eehd*|?-=&Lhbt zTMt;RAKJ5PAJd%-_=1n}s0iHP7-1)VZ0pc48C#|kN&3COnSZR9F2s6cQ(rBTnPQL``{g@jBhkb>&&il4G*MnUf7cQ5A^^Je3_UH*@b%TOA!8od=K? 
zVzMVAj2~5(3EbCp}2ZH^ToEtbw?(xN+htFivmaq>$MHk;g%O=i07=q zU%%2Gn=b>;PmC$O{?MqYbzjXQTQCT?<0Jc-pg*3D&Mz`tM2X^bsp*Lv-Ak-kb68{k z+8W#jxYRlm%r*}JKTDewH(_6TkuM8gbK(*l7e2Bty5Z8AT-lj$ytbBCoWF)BXfYuP zU3@|o>$B0VueLo?f2j~-D^6QUJiZr^naomvKGVB-F}H)HH!Kc;?RI`A^@G>D!8tCj zc<>W%W@w4rMO<#{%$heuxl$MmivwBG5o23#Y&Tu_-8ryLZJC_fL)IIJi`Rp%5j81J z&r^`SpIQCA%hPxf;#8d}N`RjX(`@te5U@7#6fO>W{K+TRoKBw)p!Zo(v-p?Wm@I;U z+tMh(W@XB5GIPPFNE4=0J9y55?zTZLFW2tYcxl{~e{oHaQoOo7geC#tZGEZc4QS`X z7Yh-&a_rd|&qbzdZ3h`Kj3panj)pCTkqD5jPL~@sxBmNs!z@OM940rlFg$GAaHV_c zpf_GXC4%jx-LQL#v0KZc<%=hET3m74^fC3{^h=Gb8`VDHz%h#u9>}rD>Kux})7ag9 z&fSud43Rb*`SRJr(9M#=BRDqAL|BOZ=j@(Zs(gSv!7(8OL7i?ff+B5J>7J0lb_RS* zBXsOY^Cm?fW0fSQoc63je>m)#{eXiVDCOild}o(E=#Z!gG#XnkdH?|&X2CC zt_sG*^=x4c zKCCTaO#!@A7W#5a|Mlra-247B{t;aSRn7@T$Gz*t!h9$QiO4ov?OmjA^N$#~9k zti~34f=l7++9gwLb~f3YXbgwuN-Tq-F@`L~cbtws!yWtLA!B65c-g*o3_xM4sfV}_ zPcFA9q)4y%fU#F8j7{9}jd|k!{hBtZ1_;H~cIu(sznYZVb(i6Cs%h}&uY$#LzRkO& zr4Vdxr{_{I3hOn$#kmB#wmWv4K&|CcjiI@dz4|{s zz;j`B=@ie;Okee0N4o}tW(;^P;x7YAdv!p%^#<|@CMlMjXIrH;sx8n=zZtmg(p3A7 zZ_;C{TqB=czH$-;laXmu(-Ve~zRZ*HMnyZ0vUhX4_Tr-p2|c_P8-2+CQMr3PhnPO1 z_P?%=Tv*h!3A3}?!(Xe+^vnfeFm6Ui6y;x4Vcg87mqba>y*bik2d9%?)B91b$sRo9 zc_MEx0%r5Q7=A zcZP*U=T+Ut4$=Geu z{RzL-F)+5j{oL*Aol<0WmP|XoO0#y&cL2t0ji*O7i~_chPu6f2X;}9YJn4<*JNOJw z!fu~FQ33MNzSz2T40~TA%Qu~YrzKuY&CEw&qrpQGnoHH!Bqa$Ng;L_O3&FlnA1TgE@{v>#AJ|fmLJZnY7*?c2xu@)} zlM+gj(bD@fniP$wHm{C^(_~Ga?GW0jvt0NXO>~*X;G&6m!T8j04Uf_^6KspJ4-@Hk zaOmMLKNbBkPBbstG3@P?n>ob5H0c|!=x*F!qLrmQ;Cd3y`pB*;;Lw|d{pbGuin1kMZ>e**;xD4Gj zk(<{>8S~9VYNVJ+hx?6PO!C}}n`c|(A6}E$inZGo{-bKzxgPhc{q+CWPq+Q;CjXbQ z;8Fa4kv0f?j)Mi(X%arCt!QK{t5b-T2!A&7r6Nv)_;Z0&+OVT3h_@@QKW%yPameKz z)ZeBuio*clZhKN8N$kn?$RH4PuQvhan~#>qWsP6d z<4)m`*hdvbqhl@FUd=!7H=7JjRtcLp_o4+TN!R8zH`~^%9zERF0ubOdDgWGt?@#s! 
z0Z;Ovb(i;D2iGAI?PSm&$n>0~Y^R|k-MV!3U;<@Cx-lux(%QmWmG+bR+7-Q}-qhWi4L;_!a`6Z`rb_9=4=~ozCyd^_1%B|BI z3hsetKBjNUW139Vk=tk%s_%y;xra`iyQ9roGLO+6+0CRpAAL$nN=j~{4l@t_y4Rs~ z0Moe>(9&4VF6T4uZn-~0&XK?~uc3;0@v9&7ZOo!DpfpH0*=AABxTK1Vf}Vo7Gy4uV z-M;g1W~yBTW<%}t(0$*Pxv8XtctppWX(+`=7~KJ_QeBy|8*c7$JIk8e{dmXC)P7x+ zcQYTa{gkp#K>t>ET;8|C4^v*7cg1&{(Vf$-hZ~0?!Pud3c|lw>XV!TqF3~$0(ef$N z8;=eoYndD9t3u2#&WMVUx@G9v-aYOdDjxn#IX1!fj72+-H*%ZwEq`fe3!UB0dk(~Y z^Qs}9yY6hr9PbrV#FH@?SraiQDrp-!8A!L~7Gw2AQJ=Nb3OhQABky{4eD~{9uoXj2 zH0v>bm3`%n#)-GKu8rptazF35eEShJTs7v&?=+jZiv5|m&<|uj5kb}Hkt4soxxD}F z)N*~&&NZ*szx*%AZMl?&oprKD&x1uqnvH1;Ct1MRn#By6*@EVwxwZRyy)S+6U4adH zs#7p9>*=+(fO=eHx?2L$yRTc>TJy+=AH&~y=ZAtvoKSCe*P?c1I)6E^^g>7bn({Nn z*OxSAiu3wpzV6)ik>AaXPwr&PlkvVZeVfNuJ@#(jb(_&_jwEz%WLBO!Gw}TOk(`tG zEn^xt6F8!7tkq)(buj6_CU@<9WzMI^lVK0O~C#E)f>Cjp*XH@l4A0$jB78nu@Ge-ly7MoPYD1b12doOsIhx zw!dsx;*At^8Xsf1r@_&2W0z!1)D`gKdmI3Yk~MV8Rh@>_h8cAODt`Gc@#ilsy%8;W zVx%uOT2@}r0Ar@(_GO*VOfI7$v}Ok%Vro%&zI{!Yq7?5aYRkKUmxk6V=F9@eZW_^L zGdkulTOZHl`9vNRU%>z`uZf~F?&d&5`qT+|2~bp3aBydZCbT(}u@#W?exc z3kNk>*VIZqLW@Gj{=e>iDKZ8W(W3oQig+tZv=1265J4yFj!AL>$fggAVJonV`P|{GL1R`ILF*@5&&zHW zdj}XQ9kpUq+`|h~S7iHUP252ZTjESLnE+Tcjb~eZ>`a5}Q)|P4nc0T!7(+&MTzdBl z;r*Nq-Ltvni%+MPWCPhA$~nNv&Mwo|`5yUb@xE+thvX8bmQ{r_g8d?&oIuGu4Flqt zT_@YDg_5j%#=d1>&B=Pe!msNjB#6dBL}b-IS$b!QFjY};RL>Ek9S8a@D-8fNbBF~O zi;(QxjX9l{7jU4vwfm`+t{;8>gtCa!2WtNQ;H@gzj`&n{E!x^nc2G||U zxoYKq@2i)tK65?l_hskFPQFNR?!^i%*wrh&Z%0_An(RGU6X#9j)VPU@50lcivd1l5 z9l$soV*6a4nNXS6C&@t5pFo7mTjLemHL>PI%h@?7?ObnEnH`CIylyQxTU6g$5s(x3 zSfegaHLWD$q@jVcpEqml*s-xZTH~%@+vnOFAkN2nA0`NH9nocs+iZf*4rcF#^+H`2 ze5vX6sQ2pkNcSn0`vw_V%ie51Vb-3EA-_I#cdU6` zgc{`1-6h_<^SU{sl6N#fo|!^O_2##&9{Q~JyPr(1h^^4+skS$kHnp2KQC7GvFGkuo z9&mS1a#h>P>kV93fl+bvdp}fGKTLAuX(`140Uma*-gl%V2L0yT#KI@9PyOvaexTba zCblHH&jhBleemO336usIWpi$4FW7OWoMAz3-d8A%h$B+lrRkm_Kso>ZyhzY zh-GE@)Z-jC7k9$&gynZ^+OXcU#`_nqy|&r8d#Ccg-S183<5P$Axv~nHZq+Qez_=HF z+ZHGG7eq#J)a#~#LWWLzvevB%+fYHo-9OcFMb2mTmG8^_CA5r(b&)vdLz2p4*i+e? 
zaYnNZtA39-MAUH+);l-#Jn+jV*Ev7EW4XKZZw=Zw-%dna@$t`pIoj>#I%=NCF|&Sg zngzx;G1Vx#!<2XH7MYn@mzJi2N&ob}|I5>N9;j6nm`%PI+rPjq>fqxuzCF_C&b|W< zW*^mY|KvKm`ZxH0%yoN?QCr-Hgw5x?Qf99WO!(pD37dH^(tA)!`T?`|)`Rg&B{2Zt(OzC{IQlxaC!u8t2H!iiHWMdp|k`uz&u^NvHLo%?8E9 zgEmYEbC5IZj|8S;db_9hZ54SvC=5 zz60*j=uthcJZ3g|SclpuF=*Xg#N@%Q=o3m)!#TtL{Lpp%_Eb~=q9l_oLLGW0n)k{6 z=&d6U58nyq%-yPAXKQPf_)=>ZU%BJIZ|0c>X3po%PpAW&kER?TPZiVJIJqH%f0c5) zsF{B{G6SFUBM0Yh{H#s43A4bkyCs5toU+u#&4pCw<1r^u9O0esJen(ur)9$u`3~T;Th>({1zB zg`_v17Cv&IoquJSA}lRl$&ieB)MKHp5g#3UiYa-wh74@VbIh-Q#ajX@Z??6)9{AXh zHyb8Z=34do==#wu={3<=?fq&VU9(E;`SMo_=oXwJk5G7Phd7q_@*wV$OXhD`{}kOq zUlVI(7C1hi7&vwKz>@O{;D%8Q*K}Qw8*F>;?@h=3vH&+<jaWP1rq=#s+WFpR4V{CRKCg|=qg?Ds z#|KlhY$hF*E#cCK-p-SA?F>fH)w~0Ru?|p24tet^3@OlhG?z%>>1A)0lK+&JPra)-)WV^(MLwnLU0htI4#T72&HM7?)+XHO2dA|jO+e2M`FT0WdQtR?x{pIKC8>eWy2G1^>?i)I&hxfx<6Jl?_z`C6r{6rItDu>DU zV2orKA}ZE>fX-UwHqsEfBc5-PNycpZb6YZ?<8s%dK|D$PKF~3VUX``V3Q(YHrx&?3z~z zkJKlC9p^)}zLiSg78um0Le@Z27yjCj)Sd_ZHfnU=hTh}3Bpq(jba>zP>fsZ9 zdbefFkG(W8oqMpy)Z(w)?Mjf)BOABdTz@TOLQutX(A;THbq-%i&h+42#kL6`dUzlG zOM{#p+tv@^(Wvdnaz=P@4QI`wWd$INftNm9wxl7WbW)z~^LR)eh{h>;w)MPo64B#) z^o&rO9oIfed+lV->Z!|jyrzLC`H@kvLv-p7c1rT=M~TB{>H)-kxzO^!&hTM>bA0jX zWixp=RBQ36JslX?eilQwT3s;y;f#v8f4pt=tMt%C)>-(?&83u;eFNuoQBqB{o|e5& z9EvUv#=bieHqK&<1-`ThC{J~>^D)7N)Vwoye*ffjLhoK8armO!iR_N*rx%OA-~LtR zZvW%)MZ{u*5f4}n&%8@v0!wtm-U&CFHpzTFiMNP6FI+H@D$pQv#1UyKXXz;KH19pr z`N*RycS53XVWFcFU9~53yd@qIN5xG>(B$d;GcJA&AH4sB9gd$x%lrQPxw6Z(ip+gl z-`PAu+q~6N@2B-T<9_|<-)u6~b6%Xa?}{~24#jt&XT`vgIiEb!^OQ;h^1KyKXp)6} zGriMFHLQf*BO8HtZEm9n=*M~putsCUs3JRBUAMUxJwmMx4&}Fyiy zmY&{`QY6S+f1rh-(Tt$>IlTDeEBwbZ-Y@#>owZO;)B#>cOqE)T}cP z-8<2ovtTZk-UeO!=J_{`;WTI$*__1y82RMzxHd^e@u^00n)>0sB;C1}XYC8;LxrN? 
zfOESatw@Lrab%-@zicSr6QqMuwHCDb#Jqr_T_>&4SVSMZeIUcfxEYY8{9Hu!Dnv?o zUs}eVLeDtv=gXw|L09(ld-R-xK++tdup_{S4?@|H?8{TW>S0YmBNP~XE1iaRoj(2I zN5`M@;pyf#urK?**(EP7Yhpn@qQ4&b<4970@!hy9T>^c6 zNVYvFmd}ai)0CkNnjJuf*0n6Cym#^(YO3syPqto=^X*_K!z^Cr&aGX%o_`;(*bi3R&f|ZS9)KVyhQb*&@f{TKGq?| zEuTJb+$_c3_U)%Q>e!mh6`hdYubh;qD(sAF7LRfO6|3@G0uQKr)})B9lR&dA%f3-$ z`*1?jisy7VJUDX1A^ti0&+*Adt;&`~+)K84?+|U6M~n%=f8FD`4ywWVs`r*WS#xI> zU015=LQU3Elq)HN_gG2opUW!iGVQ5E{;m%!)>yZBjLlROCaxbGv*Psnf&C5frI);{ zifyGd1D*1H{zi~Z#GY90#9*Y0?eTDu!-Qg&y;tU5-m*LiT-?8f6p$HnyWv1AHWGNG^mPZ*N_c$+qh6+kmz~#@fH;@+heP(}CCu-m z)a_;y6?JXI*s;NaztGL5vmw_5>1H$>=ts$S&O2l5xQxCfA2Eo*_SWX0`#c3t^1rgG z(}CW6fOAlANyfAA0`t~`gob;kbS)>nuH=&Z}(-4j1L_{9v?E4dLzis@43n(LE|=?-20JUVJ(2T7Sh;hdjsCueMJr|@Q={7Wm)#I(N#2^S%qOtcW}%3Y5zX{+T7J0y{o7PFMI4!~ ziiP1zhg-3%5?TEYc`tqS!?S?t&^Z(_({wi`qAJqOZ6_p?2lF(C7_-SwHYO0$aQa$g zUp2+I11{B6U90f`7rX^msq8-W^u&2rf+=`bMK7RqD*o}U##`vyZI&BV$=U7Y)c~na z`tpaL^ZjX0gD5WM-V@kN^j)b}%V&u>;5@9vXR+5bWG@89eI5T~6XOa;vsc(|!fX(D zcBWwV#UwD&S_8cjnO}IH?3UmET6y}AJGLaouU~ns1JG;N=5G8__1(|xbGEMiFa^>r z7_x6u#U#4eAoFGu=ZsXJQE(4&TRt*%xU6@e%=R1eIiRJFfGBasY=*Hb9M|=;v9G%2 z7-T*z{Ns0%brh0@>xcF8oskYFT5t~`dcOEygnpv8w(z$s%SUtDfd{;NO-)YVO9`Xm zJgG}tdSCxr1E-{l$qO=gY0Ncu4vTR(#JFKaOdW_OmRj$a3{VV6mIoT@KnE%&VCL?h z-p%k5MFwm;5V%2d7!$;WJiWQCw|mL#5jqL0=p5B-HjB@e{>Dc=7Z{XyZ6PXjjENxw`r|uS83tqecFd}LI`zP_&y_Z#~sNLGOe&Gxb>%G{D)|_0pp{@c z_#~lN%pYD)B-fF*S`^}Cy{kpCZ*Wqsr_WW)VEaf^Gf z@IO-1q(;%ha@Co=Tzw})r{UrI*I2R^jX4LttKM_KVMEBRMaBM62bZ2H@u<2uI_XC7 zt;fuA=N#P->NM}E^Bv|z9LX){dBZX}!m6ypjN=hK3OdxRzyHqpPXipsm7ZB=>D%EL zVc=WUr|jsV6t`3=Yn7W5!8*_J(A8Q99V(6ssIsK^a1D#mN{VY(C7OKu6{l{WuQZq2 z1H9IUE^HbM+X9%0v-F&LLdcqi#3tsg@#86x7-xCY(&Y15UI3t}Bl__fBgWn8!~YEW zbNU)6@xl)YwyNb+p`aWKdZy?8c6}g(L&xeUvk?t` z(G4A4v+etW3k1ons@;8fCR-yTyD8#oD?dl!_BErP>9!BbyPg9jfq*%Rf?xn7rW^sz5F(q-J=`!k` z>!{48O=3xH;bbTEVcy4|>$57mKao=%5^K;rSvT=nn}uEIVFIY*Y*YFBW(CjJKh1?5 zjv9YzBo0;Q70 zO#ugdg+i5nnWf<+E_$1D_dJUEp3`Ji8r1r`Yr29x4gfi<$W(CTbtvLOzD74+WiV3z 
zIjyJyaCpx+0I2uRU-&+g5_XJ>d+;qDqIZ>6zv+Fci*i!Qd_GQs1X{SA;THN5->>~m zuc0Tkpkf0($Z*QHk_)i4%tAjnH|DhfU|@0Xe_mx5KvaUvsgV~O;pr_B!31Es;>n~T zY3m3BD(s*<X);}3@z&RRxzdMVrS%y6#=vWvOle|N8o#$UO zt&r<#jx25W zrg084DG1yX@~Tu3)1tV%G(peKcUA2pY~-=H4$(z}kP7ZE$h~G8yd!p_GM`haY5>ZcpmClMba$@>mUAx=J zy2pCnGXJ=mB4%Ad@}{5E?Yq5S|IAC?UaB!bn!-UTXxmoQm*28K?6AE|Di^{PhSum$ z|GM2VVE3D<)<13eVV8Uz?j0ESYW?=2zLOLsrR%mFYU32JmofizZLfF9Kfavy&pcT> zL5u}??UCn~Gy(`Y@ERo04v!rMKl zpQCmiVI?K<5JSDB;@`Y!)6YJI3c_vZPmT9*DsST#Qe7WPL!-V_Of2BRH1hIr(w>ic z2vT2FT21;B8~o~48{7VUw76UvtK48@L0Zn{+eYR4TRTqWE#4NW{hu>(QM=b;ZHj-l z;f(~Ml)_gcbSmm2Xd1G6cSHux)fCZr=v>C>`Fn>xb2Q)&t3zv<3&{1Oz3?3bkWYIj z6@c-c>JH|V;oO})!Iw2)+^V_(v!LzTR!}yWFstgT$ zGu)_py1aB>(mLd7(ucVHQf}t=+y1N)Z&UVnDN=%l_v@-YQmQjD0R!->%W2rg52XIL zO{8D+5wl;?T<#e){>N;>=bSZY4H-8B`yr$Iep-3}G=q1tMd%n4?Inl}Od^b1XvGU+ z^4Hr&cWO0TzQ^Z^Z&u7li_QG(>eiOsqpXuC^`6i?Bn9_s7>#^bDg$9K>LhdS!_gBQ za<4}Bv1drZ{@joW-T*Y_h9dwl`ACaS(`MKq+uYI@<-!;ePYVMI6v(&f=XY2*R!8Fa!2-05OJ>rh>7;{r|tjf|@y>`DPrmYx6D&ikhIKl${y(ssuLWib%U zgLz2?^}}mZdApA~YVzsD z-f^3HTU%Z|r0QqgrOu$BO?QEM4X>=oO6nQbzm9j!$WmsBa;y4?Ai6lut8MjVmd3=X zGimT6`tOad$$MhMfZv1J4GckWfv3>X?md5PD5jsimwR$!Rtxs^B)XilM+emEqoaSZ z=-i-Q1{B(_CTYc`>fLZRk1Vq2YE_LG-LaY{(qaXfn+jAe%x`1A31nXNLtoQoe0SoI zs*{`p3M?PAH9PW$s&+NKt>#ukm@K}XjX+jO9DJhop`?lDS5QzgH?FAZf7dafVstuQ z=9d;;4QTL$sYWr_`R{oXG{-10(@-AvD3de?x-9tsV+1w>-4W%Kn0m`Ya zy=eK-2~Q9E2Xc6GtK3e8Qlh1?5aKu3HHlPYabu~72I1+cd+p6^eoM6N-o#I24X&G8 z5wk>MVZhvrsG(9!%Qg^|Q_kj~;oWRYHb50Lr%x}r^;oJY1#&y7bOK@ELw(2t4rrme zvHN^);Vn+CMC0I;)M6uq5$Abbv9>24ba%N-%oo^=?IgMl-T2M3pqT}QqD)LO$DEs4 zuUGyu*nY?^=jbCkugsov&&vL4-`ea*m*1aJSx%qtMO2+x_XvR%xL*5ObVJeTRVwgG z>gk@%*OHza4SeupG?~)I4iu_q*g-g6QP5{Z@KVKY1;J1ms#YXzjZM_TNicKXgdiuk z*)-#{S}E4{WRe~ir_n;I^J1>F>%dGXr^-7u)w4@V1PPF>w`e#fw_eQ3JXkS_w(S1e zYn7h4{?2B8J%7LOzU4VVpy)oF_>fCGGNB+Gi(VcuK)UE5n0>O_4pp{0pDe4FPDys} zPzA&lv+(LR$Byq2_EwXrjH4R@oH*~N(p8Hee7O3no(HHZ3hLqEg5 zKFiC9zteb7*Sbq-;a9yXH%kZrC4dRilM?76w2}aR^JJobv0z@kZl>z4H_6=NPxroT 
zT8=Xk|9uMGDap7oo8bwe;u~#${=B32X`3u}@O7apnE-AkN>%b%0V^V&-2=>q{9$n# zUBMDC^eyPgfxXYNh79a3A@RW}&Oe=O zGynSEtX>Zw&&=7~=V*&MLU~1egC_J%J!#H`cW_!$GMK`Bl7iYwb*LKW0Il6y;VyBX z%ISi-RXC7-Sh57&nYl8s1jA=078BB@efLf8Zqu6c%Y#~Ps%yFqd)r0R)46?6jUA!2 zuc_4S^!u`SMW6?zQTMEDL91HK9C2Lo=Pu<(nqODAkwoY|d!MVTx{g-F&ZG7|4emEF z{RO@?u;072gn1~~+J`8GUZ-k6&6vtnhW5+H=YNLB9?Wc`uBLDa#1s7IHb8kzP1u(| zO!2?Fq-_qRZ@K0hBL>$knja^R9au(VoXXR=kBE1mDBxYrL*a(4`8I)UQ8Om+H#Rh? z6E|GzX4~XO(>_<{PZa>d_JWz{tJ+Qiu5SC(A{F#Cy)Ea2n7dyzluXL|v=?iV<)bo2 z5?-6z&7NN7%3HG>|D@T`ac6x8S9wWraqeMW708iU>N~Y>B?%eBvk=n+tskb`>y^LQ zI{*6yZPO@KM@)64jHk|BTHE{iTef<;>_+2&)o*SbkuFU^U|j80Y_D=EFbSCgz0cM* zZgE!&NSU=i+x!1PtmtxMTrWS9XN4m2bW_Ix-;#1V88Qje!pK1=K7fYX7rFNji`%(4 zc;M$^4cJ&YdXU68qdCb|I}bf>rG>aG6j$F0rBTA>_Wl4d48FLep}NRoJ$(Kev~xNF zhY}IX25SVmVwz-U(TMita1o}^=lds`xkG%U zy7mWpl>$JTRvMOg;%z>3@}yg9xT_7lp=)l4UZqlz6dLUMRqQD?dD0w+RW1I?tx^^f zMkURvgELlVWX#@A`1Co|{<9qV)A!hZKH$QhRMBTB-`0bA^${(C?8qw`mAD%ojUaY% zPb?YlamJH1Cm8z`PPmo=mbmMzFU{@IHL*13hy({P=p&U|E>HBK?N1XOkEC(_AHMkm zkfY&!Z!FawljOdbPPwCuCiL3z2$gbGr2yCs$>iAUZi1OnWfwm<=a>)bb?Wmb9R<^h zazIO9@bJhTlBLDr6;C+fUwjL%v6*Za(f?u40;pRkJk!e;H*>O;&(K)TTAN93wzw$( zIrcz7qN0Wb#Mj|IX=^UbzQAh?GG{2tDQeS|=3z_)<>=mZ=CYy8h%R?80;3zh#q^m4PQ)at z=man&>7#416O$+}^nd-T3f-*gZ=zFa%Ree3495SM19qBjqL~i_+kE~bItKb_>1LY0 zI(d*J@Ke))og`1pp-mgg+TmKq)x~GeVMTqOMagf72gg>aU)#(ZLR!4}o>5f^6{tvY zKEY}FVxI2Kem1}rqA_W322hdO->2p6r>EI(-tF+``eeWQ+r| z6cwcAL0TU0L+63jhCx$+946&m(znLPTi3h55*Mjf{oTqH@O=tz=6dQ#MRtLVh;|pp z?*I8qwT3ctfdm+N0xSB=`mr5z=@(Y{h#!9%fLtr7_GFrzdO<@6Zq-)cPY^_DYtzEPKjx7%#Ic9)LEo#b^ILRQ## zIE=i-!?oy7&BHil!8+nK5CLy{t*mb4UxMloh58eG~Gu4(6qaAkw!lW|`y6?92(=<`>`ptaXZ$baT;P zP-AFJu4>ra^#K5s?(lWp>X$0o4=w#S8nj*|=*9#gU?kJxqnE$39cI$_1(WBiMP-?L?kNm8jLC}8D8UEQdV~FrVmZ3nk3n9 z>SM!`Q9i|bIxu5Vb;U&2aY;8)EaTJx2buqgt~HwfTkf;^&(C=IZ8d7DMKM6e#nu)& zTf?h91z#IeQ{4G)OSjk}w}@QQ&|m_;Sw1R~dJ#y~u>c#6Rz*0rE^+M2Vo`|=guyselYO};QdJj;oChUeeq z)5@rtgWA9KDgEAY8S&j**xleXt3B{kR@}wfTG}H@vF&WB!U+?LSQYO&HEU@;&)X!qDkNd?|v4Kz_J7~T- z=MZFUn=FVWW_!Zxs+{MX=GcbVG2b5lz 
zx&Uz4hm@L6%ot9E$8$C*AqUe14GHMVmu#3@S0K$sh-j*I5p?E)w4SjDbu~(C#|zfG z89i_t`$nP_@%vo0>~cBHN$ZLe{tv!uy9F9aT2rUWwS2JIp7nny(*6#q9mc2h5*#6G zTsC#;0tWPtriy|*mIc!$Q+=5CxXE#QWw+bVs1C64+Oni!#0B@DVfR5% zy2;iet&T~V!+lD3kQQNY4}CnGd&LiN*_|EVJxeo32DRQ=yF45(e$p^vFKm#zaJ83K zD0v^*R0q3k-&b!jp{A|ZF>?X`4;oxrME`P(4m@h+J5tVhEP9#i8L_CYSl@T-jit1{ zySqPuu2U^*_cL0iGy43#v_^D)Id#;a#Pz0q=3l!~AH{t1{0&Lz3d-ouLIjP0N{dxZ zpbV1`#bA__K$@2bibACwC=md@;xGS!`q2krz-VN1LX=53W&uGTD!+JJu{qT$8hM|V z(srvkVC0w`9$@F6{-M)KTA1s1Ar4kmbcHZzuh*Zjn+0-wn&i3`mNgGxwmxwGfrw)U z&`?rtJlw1Q?QCjt8yTM%lvQ8D{zkQ!tyj@^@89NOK*NDcchdN!25!eLUO+3e@t z=D&O4yaK1*`)97K2+SK6!9-l4A?x9D6%%AQEFg9|cX~$4e_{=nc-s}hp`NCPET+@+ zAoiy{#AYI&vt%$q=F+-JhaxqJYLM~FoOi=8USELxj{ChYH##?Is*c&Z>7vXk@sCBl z7DuVi%agp)_Z;iMZ1VK_zX#rGlxZ^^xtk9Fs@$-WxKZoyANApk>&cbigkMNNLbj@| z{HTin-$i%s;I|3LwfvT2z?z8a2ishbk0|4t8zcj!1R1+ak$%L1WH5&dZF#s^>4<+`;utBz0uB zM(0#htla;uZB-XaoGDBK*w@W82B)58Z=F&-<5O7Gu42?^yFzTYfl!%HVKg_}ka?6- zyt>`C=DPy>yK*8~# zx=R7sbz8WRB{0BbVA+yhB`U^T_R~Kt#imX4n%S=5zL_o&{paazN_@kLI{4>>ro4gw zX1X{b*J3v3uZpaH{yoV9WzgK^T$!CnUGpA%&LWaxb=czHFWDLaz(@oL{(0FM= z^Q3OZ9Tb{47h+6vfn-0z!!2?58EWOqut#*CJ)uMbyYCMf)p4K?Sm&X0lpVxhK+z$l>5+0aYZlxvR1$NiuMF(VAi89Yzg?{vqi1yZE z6cc$Ww-ZqSNes~n#L1+=mUM{^9Y*gVcUlV#yba}BbAbOGki^KkAcc&i-QMh~?!Iww z90Qdn->6DyHee zeKY=e2Vg(;q%Ul8iQ>q$!f20+2~AkjT4`m{rn9@BpT1EFOg|Cq7TZvBfweYJ3)xtlX;7o5-R* zVdk!PB=pGzS=v6*4WUdg`ps}`HX2qSKZvYPwmB=`lEg^nVV7GGDjd#_MbxN!i(XAf zs>ex3nj;pK0&hm;JJl3U*0=}M(yQ%|AWnhZ^6;J4PPQ^hR^*>ZF5}fih#fKLw@qHY ztG|u(kg1^H9>UTd+NdaNq{Imwa0l&Bs;n$a8ujZU-+Y_dXb1g)#5x&tGr`dd{tfkZ z-e6t6Z{{Cg4I;ywy_Y|f6wK7yVcbteA{Jz>`a|{IoyB(+g;4@)g_E(ik^&^Ql;#li z;pdew3`W1~4X^wr!`_wmf0ki~VU!Jx2A`iOrI%?k2UOt3JK`UGuh3%aLPKI${N~aD zioTXtkzFtZzJs`{VXG84Xic^foSNJF`$_EXJ=^WL4CJ8seIx)eNL1b3d8cU}X;LO+ltgOTL8D9np6tobMf?T&SpMz`}xyL4V+;=}Qm)|`pg9s5=a zz(uQb#Py6m|I$qQM*ZQ8Q|Y)~(Ov(A8b0>6kzdfyhV%&@ojX)Qn5Id$qQ{~dfJZcL z@oXq@QyJRjMW^INEwAy=i+EmPjG=(ri;y`y$DabkXpR_cVqPy1n&v8Uhh+N2#{$Mj z9Q0gJs_V3tNPW3g;bwuM+Z~*ynNp9lok!yihhQLN+v=qDm70P=%);kik+L%%X$k@Y 
z^2PVLS6Vr{<{ZlYjjeRPMt8=2J1joI)Knrz7^S{bgU6+!h}Nv-p$N*XDM>+4FQmt%sbjUvT(%=ZE&Q$m^QnK$Hra znV1`rx+gE8RK+V=*heNgVzbei?T(EdH!)Zsl1dj@f6>=%$_ailVlJ$ag@OGyF66M} z7m6r?4^WNY$pPXt&>e*Kn-{YuGevN{xx|(;vfK8X^FgiZ`euVxA7TA>e3Vr4C`ny` zphKd^jOawlNulmW(|iwU$iU64ub?uz8uv8U=hihU!4yFKK1QIG4YbZHOGHZ+QnB`_ z2$PzgUPwed*_W+l(nJte6e2{x;AyQ$vjMAdG|D$G16$O|lT1>Q8W3nyMdgr8Q7ZrU zRko$Myxa}TCrmW;1CxJ{Y)7mpVc#li_Ct_Ud%h844e3ipvx1JzE+ew#y)W2Nj6q?w zX_${H_y0grOALe3GfBKkGUf(-{A zn;`QRQBk7xWEz9@?Dig#ov_3YpS4%LxIGGyx0a{#YdmU2KzjU4Yqu}%UK4kJzr~W5 zY-lzM0;;3u4O&mfl0l>um3ezlhncZlPeZgSUMrG*)Fv)8DsdaQXuPW-^SLffvKhRF?`8k-bmn!qPvJnqsVrzHK z6mTIObD-%kKp>fYhh0x%^!sZHCt^9yKGGEmWpVmomV~X02)5e26{U2!;QeKpqE=A- z-fcEs_a1$u?qdsri{BUR9Ar+6F3e8{GY?n`*Ga>i@aIBMW-4_ie9e6dryjoP49xVA z=oFQ=1d_NW$NV$Vg^PoB&6c?v#f*SxkIJTxuvJe}6jlQ$L{Ey~p%Hci4T5F+b`zvh zHuQHmWTiQE(E`OB*sZ?A2?rxZt}9fdlqRb+Gln-Hd^d5Q@?$}}+8%XaN*g)*6M;a^ z;eHgi^Dpv}v}QAjV7XUY@OQU=_dC&bqbXVB6y@N-$Tp4ncYaTp`j4Y){*lyiy7)yy zgV2DX_zivoM@^1b3ZAivJ@~65KCAnyvw|OQf~lm&bj8S-fL%s;Xka-$%88WvhPTpT!DoT)&|I#RlaCckgz@Q>aEY}-}ej}VXUN?%9!He8NQ`cNJ{ z|M5t-XB|Z5cmK6GA)+d}do2?Zjcm8<^lbkiv z_DML_t^LaGdWV?G$M$^NM`&zh$EZA3D2}(s-V2*flrb+W{;UnZQ_Hnb0n!ZT2(|iV z>(OmCG$^tjD+}QoukWeT3X1v^QMjpKI7_P{R*k_K1_mG8O)QiV5(GDNi$Iv|v)Ev( zkz8wy(a99}vWco@>p1n26=th-}J!Gv8V+x5aa2Y8m_lkY^=ftlhr zjmAF^ugBVd*2*YR0WDKmcH+G3j|Xc81Niywy1h1ZR#6{mwFR-!C9W2>i-IXq>t)dO+XquGPDlq-SJJ8n$|ztDs9F-!AH6+r%snX z3HtFhd5k56+eX;9NR1dz!_>(j06ZIsZ1eteXPP9l06G|DFN}&P!EdxPe2d=Po-wlu z`j|WdlnoH=Nvgoaljyo|NWLx|ANfk}I@kXyFXiC*dvae>SS?6Q;T4!1jvGO3NqDAB z4gT(d0Ln%_nKU3^lA2oqV(L#!=>E#hb4yx^wJ;XKkK}FtZh4A;J|GJ_Jy%u3s3xvC zZ}H{ojuwm|+as%Uz48kG*R0ED2Y#IG3?E!aB)kd(=X|@}G0=(7en+B%1RzZwhAQ_M z_0P!klcu#9=KmJN7+?ye+>BV&T23(wqYmox0^>db>PGb7tJ-2jXx+5BDfDw3-0Ns8 zLO~gosNmwfMlUGwU`>v_M-X);aD{aBgUv$>kBu~mWi<_`DUV71Ry0blM zfsNI2cnhNiAdWsVNh+!w4OlC=pix!$$K9z>O~A60ZYBXP_DBuf17#)YrG-xw_vLa1A_MW-9 z_NJi~GheBBATxo{{NJ%$3FJy0)r{v5om>!;_Or-mgEv!BdX!WPO9lt(Y*>yNKvVrV zMr95<#|W5f`ODHk&Bkrj+uX3+C8**#iU{E9$}~rkXplK$1BQzUM``_-%3RsNsb2*U 
zq%!$xnXF!ADb72ZHqng`6jj1I9Iin|(E||y;{?)(+EFJ1)J&`4mU~wi0wv6alvTAc zuP)8Sod2SsoTZY@%-VZ#>xsh>d?BAR7$JsZPE@H%GTf)Ijv4w7oW3}Rt87q>UxR@5 zOY3}R%1r7aO#{i?pCkt+u@ba?V%R?*qx*uN?jVsiEbjP%!%}b<;#nQr;FfG|_d%Ef zmzG9~aH@MebvXE}EuS=$Aw>#EcJQ7PhjDSCSJwbY=02iPdVSWVIdw~{PRu9c%uWcM z&IY+-jQ2FSMaBxFIgjn+s4Nhgwb~OxqUtD}I-i1wlD-sDD6@)@$l4)|nXV!e;Kl0N zqpJ@wex75YaMqU34zH76HxJ6WIj@f@eZM*-k?v9!gH#RLQi8L8bn%SjQl=(E$ZO{=C?EauPrz0_Y(6=MM(YezcPFAe}9?5ivPW_ zmH+$b{(Fz^sPzCumnQQa!yBYQV(Z}qXPDx@9uJbV^h;vq87;C!*^}B(fy^ zxQ519W=c`D8c6F-9i|pdCV*);j5V`q+^}8jik+X<4NU!1AR?1~-;z>;bL*#J{Z0m; z4Y3U%el~yogAtm#qwt0b-(b`qC~4&uRr^Wuv$`cL=xC^;X;jqY=JI)2yuf)|U?lCX0~%4ITWrc1S~^LyXw)BQtBKw0XSx zPe5I^Fry%IDUCuWs&4SxQL8RXx&9wb7O+atO!l%-^`qj>KOz5@F&Y{zHC%C&-qI7z zt#Bs0hhbD{zm7V|u@f0HiO`l#UHHq*9I07aWR7sZbPoj;wo89>zOE8WUo#i^NJ6M<;S0kR(G$DAs zIP+wNtT^4QCZh+LlX@#h6_N-7t=5{t;h->CZw=l*z0)k05fia(+5(755C!Ha4cdf> zf-{bWZu-7ft~Ys@?|)VrnH&Bs;KHh;Ki7IRbXeU2p;ir{_St=+Y}-lx?Bsoiyz!q6 z;yt9W?tG>8#fGRp#+L9fbg!(J1)tO{ApVTSsCH57&;QfI51YKLPu0&qeX{c#s`1gR zg=<;|xVVF+=7C>GH>xS5#nOQDJI)1{=u`X1vdNkX0r`;}#k5tKbCipLe7L+#(zmCM zJ;_3nO^>o=m7ny%*8US}zdkVc++^QoW_x~uz>CKM8~})8;#mF_bW_i(F{_COxx;Ld;rJtQOTyh+j7=}uszzZnyi#s{6p=RnK>Do zHo5*xHvkU`5U-`rOA`G$o6x<894{gzGL9Jw>%LpwFY?`n?^8eMgm{M&;)Bq}-s$C~ z=_fuvyfiNQ=)Ss-bharN`uq*moqA|A{o=zjPM!-`FS-c+ONw`Ob*h%($nH*yZO&-E zQkm`Fkf+(S9Zm%re2)s=4lPzbiq$ZNyAdSj&EoYnuY`l|P@KC~!EuoXTH9rV{ zrXdVGht5e1I2v+$7)XVNh8W!et1RNYbjXVYdMoTIx@azqygP2rSt61tD)3h+6Bpq4egI|8?D3u7La&gS;#g=qu z6nsK9HI+=vlxY?dt3xGJK8&bDqlxEjfDDJgIBgSXqb@*VDy~3!Uhy4t&O0}`t<9Ls z+&POZ{Q}CtDsuzOkN2^sI5HoV>Y9gC_reyz9`pIj=`xfz>G-am zq8M*!MP&f6(X$nTWxCZWtG^l!r4VCNv60vi3;6pm9;sJi?+`be<|7wZ;{clxR?g_mpE$z?x03 z+1lDHJ(DW$7c7#(E~Le)UD&G(;f-`F{FpKut>&ho*AyxFnVfq*tlhC@uW?e15Eg^x z;brWvNxp!zW2|Y6kCo5qzcgrN+DT;pjDR;PZ`qJWqCXz;7CQQPIu6aJ0R|}=;@9!W zO!X0xK)b_w#kE=6f8K}5HKm#qs)Gw~U~hN-RZ0dS#TQX!MxBJ=4hDs@C@&k^h&Eyp zq<=sUo5>#$bH+s+`iTn2I6{geevTHD20z!Yo!(FBKP z*X+hd9VNtWzTvmU0>_l%jb;kT`NIPPDV<{N1<-bzG8qGB@i#ajnVTE?h 
z1Hf}b7s36!e|jE7iiSc+A1;BqIymCO8KA!i*80JPSD9@Wp;!l?X&Tw)PSQOis+Ywp z|FQJvreW(aS+es-X5nq<4Zh=^d_*R+sK=gV$V?i1aZ*8&#)|QCwAo#rUfOR%ihPEn zNeJ59@AJFKYtD>+c+D$h8dq1|-BjM#$58`#$4tw;LNz6{GBU0)b5>+-P*7pnuGL_9 zi@2Hdq1by7ugsH$S`Prw&UmN>P=f-qi6M;(6$~ zDm1oSt38dwZ1OvP7E2oLwB}5!3EqckNzbp=bo5ZyB%(y>p1m~U@?nG?DCH|6+q9%W z#MF?SpX8Wvfp~`X{w$1tZrN|!_a~C$PT-i#?gpWj9;P7@Mt3q(SC``C18lK?+iP+! zdQJDTOYNN0y5^b)rQGpnQXb_5Bsih~YiW~L`OsPHme;Q`R)geO`3k+}XV9#>!2_DS z&s?6ma~dAzLrD_sr2OvJPL5$@_doZ_T5Nq45BIvvSKR0+|iyr*$Un)r`gNddPfMjGQD5q_r0 zItEA@O>Fk?*EtTxJ-TSONZ=sunKe-}TlE;^dma_|tB1XhOHpHY#5FTcxbyM4wg{?I6K1CH)E{CCsM11Mpu2`*ng!Qvk=yE zMSiF_LRMhX3Q;SR8$0A8(MRs&ykS> zdrv+ZSkUz8f2*!H4pfoGe5dvtz&KwernDCd7#cUeRVoOLCWV7cvz{@(2YI zB&!H=@$oubR$;Kx{&}ZFDFii@P~KguLvz)NGC8I&KC>9$dR57S+Uqy7+?4)!lmBWP z;cz7%caDbonED?C@yRfE5~=1e7xH6aqg!HcG3h`vy-bE)SxuRrl!p6g|qLq``e)H?dNRXH_>U(X2Zxq=C3zd)#1RDzdRFF`1H@0uDrDKpk=q{ z`!7B7*G>C-KE0`n{Tpxm@YfC#`~B_lhJSnB`0!5S^sk@obt@=s=zYKNAfNC>_PN97 zPKro=bUU}W?!9r1*L!vUt)bI3+Z}(iF3s!z{r~Y_x7DfQtr`uEwbudUNuM>;a~SkR zKSN?_Z8^5M}U6T@V-s$ZMawg0O2 zzx_U4X8O>&e|B`*{MT!f!Jwz+Q>yBm>SarpZl~tOl@+BUO9L940&?IBC+pr$;0t#V z?jJn7o@3@w=){6UJ(^zoB3CtfxkK$Im;Fe?7fIz2Nu$UpvJe`rK0RFbkC4e8aj3uP z80D=)yZ*Qf74WS+i;1;x@84Q|Ch7`V)6%jZ+p8+QlP_}i~AEyO@${fZ16Wn8nSn}IVtON-w#{k<->^^3R56FMIk zb+Ls3mj3^0@5{rfUf*|Rr`n`xlXe58m5ixqP)N$IrOad|8In07W2kM{P9@8b3aQ9E z&qFlHkSH^YP$W|_EXsK9cWHlr=QmvEoZof+I@fi$_O&yu_4&M?_kEuGxu5%ao^NB8 z|KaD~0{2)qTB1);Wp%ZSHMFX+e=Yz05s0ULAY)C~t652T=AS4Uv3Zm1K0cz7nlsG( zf#6di153a&$K>%k6liB*RC|Q1mmX3g_0U1h$ed*Ux+2#<8^Dt!qY<{c2SCy)97=8aA(5JRPpC>3ADYJVj-NP{<#Gj5yH7|G{O;Ym zc0>2~**7AH*mJge<5LJn3eX>k-D?nRgdBuY?;*-)Yn>cyd5m1YK4g122z{cEucKmw zeW^Pid2v2EX2~QddxB{0$$AygbI`9Y zh@SH!AN9lSYxe`%8XpVb)pB<)2o*O@!N6?o4gmg`2i|k}VZ%%_rnrQh>>hH8(Fpdu z^Y8~lOj1^LbPWKO8CIwUJkn3kO{LLjn(1d0Fa$mx zII#e2k{O6=>OSpFAgge>0{}1ybHECXFza0h>lgLn#S1d-mejoxYn(WbQ~{an$bMlk zDM{1|JffW~sJXb0A3sv4;QZ7a9WxTew{Jg2+K}%0z_i`L;p}syQhyF>wuE$fM2Yijb4mVoq(t6QvE*TBzI6iO3>1)imNXftz8arAfd^$Lpr4L;}w(5LPA1z7mwVIMh$HKQ)NX( zM-Rj1V~5D!%AMA5P1 
zu_W$UpaSgeQgIbS6I0NvG@>ye#|2OF zhJs4OIRqbste1=eZfgZg7)2~TT05XZ^of24VvY;rU(}igZr@(Tx#@r=>AMsh82Avo zuKj8U2jXof(mEdaQWJ33w?iSd)_Y}BTC#HUQAl+W% zRMI$x9hL(Hj>dS}Mt}Uag z!z~ckLL4-RU@dGfN<8vd=y{lnak|@by~s2&n)||&Xz|NF>c<|bBE!kS;pAZ$Rv4`m zaiC#{fd{4;qsU$Q>eUyJWn-=vu252K$#QhsFDwUwS}i!A0S1n`!Nncr6&3HX)9iDH zJQ)2kU%psU)qw(tr%ElmdGlriW*T-ky3CE8Ar5JAE(hO(+vvHI{?^ZS*REY`vUAAm zI3uFT&%S#X^NWQlre9pm%gbZsV;*ru4(7p8?x_rMMx{q#_RV#>Jt5zjwr+%F{uxC$$>HRUjfzP%NW{}#Fm_pl@L_B^W z3})vR)>MVWRnmRjl86&*=@@++P|Cz$tCs0p!v?rBhj`Ws?V9X3*mm#UePL?+T+;yJ zw8Qhm^RtJx2?(tH5(YY_s2w?b2G{oT#p@?${2N zK+38qB?2990*6Pi7oVV`4)gd};wK53@idj!1m8WIhtc~q;a+;3d$&ZBwL~2b!&qBv zobm(NWRy3&NVWwRDJ!)XzsmA`J#%Sm5nGXY%|nD96lxg05Y8%iu9bLp7ufC6PMDmG7&saV$$eaH|(|lkc5Ah z52T|!KN;eZ>g;7(BqcS`R^UA;;L5NkPD3u917n~Ov~R}MSb3nVn@l!uZq|`OOC38p zWIF!rzY2XHzkl}}h>rp|C@l=4BX}3v4ixP_|}SiJ3)2o=g4(i)gp)-#(H&hwf#PZqKCRHMIG zyU}o4ex2!k4!Zo2BZl)45fRD6>4KU}NKD48?D}A7;_ASUL`-091bBhK zy=S=VC!~R~wXy+ty;m+= z#P?-m3zso^%WqjsC$dP5@VPONby&X%6(b&%Oq@qE0Syv1V&~7u04w05bL&5Eco1)x zhL%QOkOzQ-06ki+|NZxQ208s!bx&?$aE%1U%zrRGvaouxFQl0-z#@87@I31FVvw46 zf`m*wTsiM7nBO1@F_SU>@y4FX|G8!oNskjWAhgcxj+SEBTJ;SJSpmV{e_jpwtz3^u zJE}TD)u$5&=@|Z{Wo5ds*S1~x$p=OQ^U3MniHzJqP0&ey)HFxjc`a-X`-LKS9e3M~ zBDYMSOUI#=l)`|>k;zr7j_MV2eI4Ese4?An@38XhR%ew0vOgpW%}>s4cph6W#4fVV z5$*W)0>1w6Shwh=8;eRkkKpotj`_{8g3BHI`S8_^({W;z{EpXOEc+{b?Zp*7e#^}N zeA%_^UPY9<){i3$dAQLa3d~S>&DKTLTVf;2g;PkOC}~J zDf3%5&6P^aIec=zCU-xenwX1i>f!}}s-Dp|` zkb?>bs_uJh5|YRt7N?g9R(n5xF<&g#$sa4%xs~~!;L5P2a)%vc7PIn6vHko{tV4cA z#wQ~MX?I#QGJCE=7jw36+ZNIq#(a!&6{iV&#s{`R9%gMR(japt&*ua9*8lSPpiP^0 z>Fd{4xZ5d&CDK!987@e2oT@?r3#q*C>lbzH?JC*{_joC$!jGGFUfpeShs(1O$*=nzvmy=OA%B6_1f8@H%#!$E>-YFGrI-2 zP9mGEnQLcgZeig*)RB2M;E8fX0$f)j!b8BxQ!uZ45b?o52-TR9_<^)Cz$C*YKx?fe z9W~-PQ1*HOsqixaGJE!DK_sH4u1+IyZ*Xul!b`(*=g#rcL8vKaTQ$^@Ni&0EOGF&k z7nKW*eo7VXoopG8l`AgZBWck>u}aLBtDV-fCB5;bCP86x{Zrf^!c2V z)=-JI$F{Zp)<&t1={4AbCuzyB9p$~87TuEj12wpgnq&_85901c&%w~{q3$((y*zc+ z*VAhQviH;ssCP)aa9eZ=%t6q{&zCBwYgBUSDK!I>B+ukfluvp5c1K0!)opR&mw#g@ 
z{E%0^@|%>xVO3*S_U?fV_{LFk&>i~`B0)2)n+`UbKDcpXxxc@ECb|)D|0sHXBYeqr$Lp8B9Wx1PI-4okt=$Z2 z?=|ku4EcWHcHf?@5)yR)WXaf@p%D>7L8fc>yg_m_6H%xhz?8~s-=aVbmG0nh{ya6~ z#oPP+NBtIFGOGC7QLn>Zt1TB6B0}eykNm^5Uu7`@*pHe4g-65VVMI}=hTjupMk;X_ z>qN9uB(J_S2@DL(fzi(`Mr0FYhN0Ax&Gvwt0u@B(z<<*Vb+BvKtU1%w)#ZYqKDQX| zQ6D=%Oj5EQ5x|%9VRYWGr`f>yC>*%Bb8GDLtKp#~;cKZ=3Q z_)2(a@v24Jx&Z;*L0e0@emlL#rdd@@O^pzxgnCAiYuy0BJVAhb0KDT6JAHvNK`TZv zB>492C=!b?>nOn49w&lmtytwv3V-~O0D%I@sj0la{Q$C~F9`lM(EE&g)23vEM2SeM z7l69T21(rH0^y9$K+A4o|4`4*#lsUzeg$w1QkYcfTThV;#2H0*0c9*;6R?-&<3BiY znBqa`xn%rCk!|8tX>^B zyzb+tPaIPHPN+HAicJ}ZBAEoW$A6W%E#+JvjbPRb1+=`=Mvwr|qQQE_AMf2EUM?i% zj8(I5NYJ`|^Cnq>EXN*2LY-T^i}I!PA^C@8a+oC}q~*%01~T0)t-Zx7lK|mEJT}Mm zYu8lK|6&|>!{c2yq~KOF!kuF^`U-_i}gIdSHdS8hK?z>hujaJ^4PZcvOjzJ z3wPLafeHl8w(!kapepoW7R5wGMVms@5got`xs94}aB^m&&Cx6HI2PCw&Jmx8eM>Vr zf=+gN>qTIO@6dg!Vu^p#$y=jeg&f z!dHu4yB;Y_APX`B%1Xn&Ex3eptPwTV!)Uo+Ve+l>d1KZ&jYPx39HbxG&_ewjXjp-Y z0XK3BhzN6b7#1`m%bNt%+9|L{L1sibH+@4E3b{tWnjkY!AA11=1P&3E@XPbJ8485U| z_avI|-l2z)#t#F{>5ofGOLakGa&FxJj#1egs{ONKALIR8xDMmJ}8F;*H zD%?-$`0;CA5SXM)VAhR34%A6nw&nX55OaNU*@8gi9+hO$er;2Xf`m*8c4DGfVh|u)$@eA4=h>}GG z1YTpQY;ca))@Np>fhh8CnnS8AjM8-Hh*5ZM(yzb%dVcycy~MjD+pfKpv27uOjDgHjSb zOUpRy9zuMNwiic@PTNF2@I%daD|n=#ezM)E893x_gSloEYjRKCjVQc?272!(EqFkD zxy3Vfj8M7<9&#=<1N2dw=>t5`7D@}+KXqFXA1kt*aNTW@zfHnr=%FvAC>##P2`x?i zB={w-de}*M9j~lsb3?B4z$vol%KnVk9m((nXu3=BKur>2rqA7u;Ehy zKpmbl)Uci>lKDV6nLm#IjbpW{pr4S6MZEaPTgX@iCLDTZL@T%stB(&tQOT6d!Nz?`$d{N3tr<6y0$bpOU%oV>eBYVGE zR5JxrbJf+Qry_o%u|>$uz*9d}KE<7#cK+)mMmM#XK7ikGKRGDnpAP+$yG+2C@p|NL z^5uqOqx{~2g_oHEeo(+c^_ZOaB4QO9(NuT&7N5c~CYZY3`3LMRFvoY8+ikJtJPg0> ztF_=OOplXP)R@>AW zQ&YmA<(@6(zPOxc3RS=qD#oR1T3TBAP6XntKappvi31<2h}&#{A>3n8wJx{VWvDeC z_ii*v3J{SM>1U0Yh9;lZJu)qv0!P{8ZQqus1}V5IHZ7Pz8AEiPJm5AOy>^EIp)j>c z%O|8?i;5r6PT~EjwIEzi%(gxMi`c^To(QyL3FX)S?%@^Uj0tj#EffE8HWOaL=z$Zl zd$Z`b)QST(plhuRDS9G$Ngc6mFVNbq5~-=Ie5@e*H7fu&j%bWrYH(1{W9%+!>6_iU zV=Eo+?Q`sR+wph!Bb%sV^0~|tTZ2~DdBiJlUq45U6V)s^E3gjO5&twl7z-Qrw-lyC 
zV4so9mqD)7`Hz1#O<*s&W`Wf2ZtWdBgd9!HYlnx>yerYb2tQ`E9Yze(o!#yh*pX}h zKxJa&Gm>4jXvGj=x1oT5fF>|1&LF=$Iac=a7{CAyoY)~Eass%JNJ+=B_YwhyXwe83 zAfHzNR}7_N;RP`EMDr{BF07_^rb8MA&UzmFQx7Y|b3h#oPGz^$OQ>Lk2d1GsTRzbZfWdy_Lb2fDBzwaBnQ z*Eoeyl$~`31KnEYHK_fi>O^9`yK0S$!{%+=}{x|u+Yi#=Onw;)Yc znF4M4F@SDOfD8nG!ZFd&UvQDTB~UIj9~5mMZdDy})Ng34XIBdxrejZ0v>^d0CIIx8W96Ei_IYN!Uw|qFWPcy8j1X@t+tx!biiKrU4 zg3gLCqE%H|EI)P+gYp=MFf6Zi4$baWXwuj3gokgWs-qFc9^@PhywN#hFmr_akMo%) z8x{Ga#MU`Z9dx{&;WPDOK2AI@^z+s5DdIvE4tuHXtmi4&Y8?|0zIM6Ptf5=E&_!%W z%+^dt?2~<=Pdy;S7(QoZ!7hq`fFcdR^Hkg{4tdB~53mqk83Fs?&}h*hQQe11K3*3Y z2lYlWFa=0mjY)@&o#;YT>X0qP!O?rC!!LXY&y8*UuwS^}oy@p`+nG;s??%%Ry$r@s zK;;@o_Z7IR=TESY-64`DEtJs+R)ZbPKZ+N;+;72@Dx|xPoT1jD6+|4`VCOMHUhU~c z?SM8=Ho4rG*jQ^kR6@wq7khaSS8uD|t2hEu%nC4Amd}45x?@bVBL8?E8=7REwXSjP zd&Ar!m$$^0Y>gGvX9 zj*CK@pi~&0WTU68&WTd7RaGYm+sMwz=`4~{or`gm)ZWOhrNiZYQm{Tt9>YFcBgcpL)%Zh3%U8)Jh-fa(sgpsYd`vG5#b4F|!!vTQ@Z{&1m7%2{Ffh+Cy~#^k_j= zzHJKgHje>N-XCCq3o_^pW9k7&{SR1^IJK9^VD^phZqr2|**1kLw+)HN%^`Z6#({o1 zx1xx8{CHYioZy?v%F)&m>^rCZ8%c10j=LEP)CAUiEf3FG+929CI}98IfY?T%Ag(}J z#2~`tlw$Z4J-7}f8XmkC8}G2+0&Kumv>P@-ZNA$jUJw%I@IRqpVLIJTOGU8ZYK|$t zS!=H=3!^{+K1vIj`jk5)At8Z?sLHA_=^P}kV&gwXr))moxoF>X9hsFMqA}2EEMFsb&jvK) znME&LN$T3QYc0r>lBba#7{iXGwInehA;9pc5mf@@qu{Zqc)ktMcQvKQZObJ9MKM?d ziY^#;F)67QZ2zyq-C)lLF#vG1U>H>u$?|^e0u{(6nj!;~16Qhtp#CJy`OCZK%HjLf z;&we4CIM-j+SVX5AR`iz7F_B;U&M5f8m{fb;O%x%UFwwfasOz0ejQGfBd{)GYp&>U zM@a{a{b$BXI&$52Db_fqhUlg1f@`7!$m=l!J5fcGF?52|xo`n67z@^n=4601&Jtjb z;AI*22FB*2@Kg(L@*Vn+QZna04ZAot6S1Fc)f5?85=0aE#Z1}|UiTyoJGB|b78V#7 zd=sNroA&%ghAc8foP>^>kVc;R7|bBLV@EZRi7iTNPeIcXWCpFzNONkOt}IXr{9h~N z0-90jk6oF#Q_z}t-yI11$=6R;Nr@dK-}p}f+b#)xG$A#_yo^)OIS^6~m3vj-JVFLa z-!=IuslC3fZDMNSf;9bo=p6}N$-i~0*xtR3D9N=Y03PY@j_iocG&B@XN!HDLjXEk@ zz}Aeguyq3nWCkvW(>>ZSY4zss60e~wumdH~m}9GtLYZXyqt2jKOc}BmtHQaq^gQ!+y{{8 zMdINe$fKS{(fJpzs|JCjoz1tEJlX*^6oZLP1^|zN9tmea*kNc>&my!d0++qNW;TLd zy%~!2c0^EvX!w=`N>Cr-6Jkdr(d5ZM`xVc&H;||!G=>pfYdI*$`#%Na_IR8eaMK@d 
z*Jb!lUB)N>H}JWN%1+Q@Mg19W+{~qU$#{(dk(2<;-z&4V9+f!%@o1148U33ga`zXt zKmPcOguBR?&&Y5s-tX9d2vj*2BlGVz?&x)6G2~#uS+&~BbQ{}4XnLFa; zdzn9W?g*GlhTpK7@w2A#_~F1h+G>6l^A8v__9!@*>n&H>K zM#2wQPc@dzF~s5Y`_L9?Y*N*S70D<~g;^W^o#-=pn55EC4i*Dm!+Yh@ZGVO@+2_a} zMNu_o@eK&)R|Utr4;-&=_>$bwU=3t?d!8X>@DNd3KLSE+3kwV1N$^t*s30Fl6O#nQ zRM@PBzuv0d2Lm62nTUJ^=pP0%ERL5K-065-POH8~fb}KO*&=1!4t`;dmtAp{6sKJ` zmmD>!m<2S4Le!(;Wp?#EJ-^`FJwma}WNeHn^fd~t;{5#bpvKV)ai?>?CLk=AB^?AE zci%=ZvWS1Y11H#KZsO9p8JWM%>t~q#jZB|9=y)D1sCt~Z4@i~OabW3rK*}FhzM}0` z8UM1Xs^bcasw}2V7!w}+U5>#$7G(ZNw#UJtUVE-pDl`vC+bA}p+k zC29gU1^fKOpe6+JEL2(OAlX4CEQ5x1M&+(G(4O0^^XQ4Q9uYiP!(_zpFcX@nHhzaZ zg#TFMqB{%N9B&q#>zd$0{G_} za7=>5;Rhm`adcm`#5*afsAz-Jc7`>!#vy0_)ueq*rwXDbcfdKYI0sPrp^RZ6+Su19 z=Ssvw_~GFy_!2WVgi+n%!8_#pZF}^X?h_zAM(cdDX#nAO?Q9Y9kp3UYc_bkcV?bml z1ST5p8^VGZARSGp%^m66Wn>J&iPnQ#e+){UNV~~YK@xXF$atJ0SOi+j&hvA)iUeeA z6sc4-oV)r*3LIW2#%qPoMyW{x)OUjX{4Xw+tO`Cs{>Cn^YY9~y=?C1P$fiS=3;;6U z10|zzAP=FYHRRDEzN-~+|BZ!haOvbZw4zFsq^dqe`Iquwhy&D$U~wL{h-7tM8B zQ6(dWD`M?YDUXhhmJcw8HcEg*gC2v6+gskzRHO^czH3Qb*px0EEK2dKHtb%;^ zN42AiinH+QOm;i;?%lpSyKZZ+EISFkfNy{s%VPG=i`vwTH{V`(I2+qcIeo~!LS8uz zJH7yYfDMKq;B!Y$HVPFghw-LwEXR>YZ6q3B1o?^BIoEI87@xfDBVSQneF~LqJr4WF zQAF-Kj!di`_~QgndmzTt`AtDKTt%iV8qk)~^^ z!dV3ekj!S~8)I4Kf(Z3-vD`Tm-p-AV>-R(A)BSuQXAEtBBYKxT1o0`1H-rYN05zP! 
zSRLmQLY5E$8tIH^q=W%>1d-Jt0aRJezh#;)APgiF2>ujzuZiQZ$78|lN zX9Umq7~tJuMEv!jnB#CyYy0!$&IyW3v0XL}2SL9a@KSY!Qh% z!mdK<#?#+SrXM*2_MO9QTF_m0E`0xHl1ldi-$ILqB@n_heq6$2WRSKhXeGKzDpE!v zFT1`qmq|6^A%20}h_}~~H*<3@tEVl*FSO{iw4GW=A$AN^oSW-qRh%#CSvCLR_M&Ar%zz7x#~`QJwb!9jZo__uC;}rWs4dMZlije- z4)kCG&}U=PT|<;*HpZOb2{MD~G7Z5k-F6z6fyf3zyJXlg67_&XBd^nxx&CRpV{%$= z#R-I3N;D$Qw4R3!`6P{4KQ>tbRlUj+%Hmn4pR?LkEGrTGD)xKt+xur}NuXF$^j{ew zYi4ABCzP)@`NNKxc%%-9zMPkifuD5(6*$Qs8Lmjx3sfYWO76F!LzG3`4af^U2=6rM zVxBTKhtil7mtKhw6U56KX!!;q zB+mkmMzYjzyBwNPv)h*A4P~TbM3|%mV%FUx4Ek(ELOO&@;Nk)5WhT(oz-M484Yowv z0}SR#8WETq)ev2OkQpfuefU8_RTVk`5?A130+B1XyEg!y?`T3%Jx|pyo5Moxgd1IR zWBJZ8#_}~zd{tc7((jw3I5+-WJGVl?c1W%SD$q)Gz)C;gF@~9cAQ8Fr1XWpiFLi>{ zGPPEV+dzT;HN6I8jwLfZDLtTKgox`TjWB9P)59QpgUkY?N00m7S{nRh6amw&MT@*g zN*BLE3LqIoW6Ixo*REYN*#6xABl6F?``Sw64j2PPXLe{?l)ZY@jNr+NKy~chX2`&m z4or&Qr?eo|*VB6S#gaK-41uQk=i}(B#@)9LObKg`)_>j}xWn&ZOE<^*3S(~&r9PqR zv-SaQFzc4I>0np3QIQ$^HYIu=_O#Kqa12LN3yGl-OFqP;lyiXjDQ9cAd!^-ck1G3i zsw3w{6lRqb6>8AD^P)3e_PM&(XUjlW4}>654nbMzQ2uM2R9ZA}OzO1M195Gl37mxw*oZg*n_%c{!L(R zHKL>+GqEMxWe$>zQCu;`(k*#gi>L;@e;`RP4!7iybqLLwaNzC@IOEzO)bCG^7?UTa z)e3ar1Ex^jsHV@R13}g?W9Vv{ypS>AMm9bxOETuk6Jc4|Vx8U*0Re$OZ~+FMs~yUP zR9!HBOh09w$VF5FHmzCn2zog(qmhvHBKA-+1W|@sGRp=oRZnnaeuev~lBltK@#nAZ zNLXj#c3{LU!JJqjSOZ!55bg*1gLl~|4`+-cXSa6BK<@lT8Fx+C6+|C}Brh(5+bBK2 zi+Dv&_lVcX!Ox3U#gQvfGgdWzHpzUdPz(y@BF~pV%rx&1(HH2O{I=#};fAXyJ$bO{ z&7IzO?bJwh4UO!xHN2@JZF%2c7~?he651MrW&T1fT|Bt7r#R8;akwzt?7)MFt>6Y} zT$?rxSy|I_k|}*Stmv;4T2`Xonr3o^(o@h{2wH`x#P#W@V<5Q^XK4Jmg=gksonc2A z2} z-HX>goX34fF?4!df*z!TupNB}phkEHRq{hrmWx23OLicb=B^;88T6u#}9M2I0;=&T+#4nXev#@eYu~qVt9PxMF zF%~zV_bYzOd{#mnphP1QV8Me2>5tK%^@kljV)!5CavD|fLe&1Y{LBIW=>5GTuiZwR zjo#HvVhdTpwZaK%rC*obth(4***9Y|>THPo3&HX@S~?T5p6oyLqyLqlW`g{hid@GC z?Tx66cL;8@9stbpQLZwz7FIMFE#BpPWf zaJGg*0aFu1;pZB1F-Q@r`*e^h(+1>U;NgVdrDKu0D(BIlZY8KZe!gPK+J2A zh&Y1{xYeR(TPzQRl`;P=sDQj&j+b4(Iqa@&Ti$!;)tFYA53aBhF`5-E2?XN=M56h& z!=g4s)yCv;eo7_QOO_Kph-QO!L%hA>gKncF1*b=9djcv!ZRYK0MTXC5nw6E6>Q>o@ 
z1}X+@#)zMkXhb%WylS!M+q;j8F0t>X&SIzQf$`CbS1&~;?oaPa(0$K8ARw2B3&94B z$$UV>Yy-AzU;9R!amOS-|(wWqobAh@=RQ4z!ztXq}F-iBYcrxx@SYXJ#%ji)PZn zAF`@6>T9LFd#f(eQ|g*F+8Ml zz^_&Jf5~!*SNln+9G*e0crPRv9E^cU#-r_P%B^|Q02-r?Uq{!^X#|2seaJpYT*iWcs#O~($C P`=Kf*A5D=zeffU?{H{~o literal 0 HcmV?d00001 diff --git a/doc/en/MiniMax-M2.1-Tutorial.md b/doc/en/MiniMax-M2.1-Tutorial.md new file mode 100644 index 0000000..3579903 --- /dev/null +++ b/doc/en/MiniMax-M2.1-Tutorial.md @@ -0,0 +1,198 @@ +# Running MiniMax-M2.1 with Native Precision using SGLang and KT-Kernel + +This tutorial demonstrates how to run MiniMax-M2.1 model inference using SGLang integrated with KT-Kernel. MiniMax-M2.1 provides native FP8 weights, enabling efficient GPU inference with reduced memory footprint while maintaining high accuracy. + +## Table of Contents + +- [Hardware Requirements](#hardware-requirements) +- [Prerequisites](#prerequisites) +- [Step 1: Download Model Weights](#step-1-download-model-weights) +- [Step 2: Launch Server with KT CLI](#step-2-launch-server-with-kt-cli) +- [Step 3: Send Inference Requests](#step-3-send-inference-requests) +- [Performance](#performance) +- [Troubleshooting](#troubleshooting) + +## Hardware Requirements + +**Minimum Configuration:** +- **GPU**: NVIDIA RTX 5090 32 GB (or equivalent with at least 32GB VRAM available) +- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC) +- **RAM**: At least 256GB system memory +- **Storage**: >220 GB for model weights (same weight dir for GPU and CPU) + +**Tested Configuration:** + +- **GPU**: 1/2 x NVIDIA GeForce RTX 5090 (32 GB) +- **CPU**: 2 x AMD EPYC 9355 32-Core Processor (128 threads) +- **RAM**: 1TB DDR5 5600MT/s ECC +- **OS**: Linux (Ubuntu 20.04+ recommended) + +## Prerequisites + +Before starting, ensure you have: + +1. 
**SGLang installed** + + Note: Currently, please clone our custom SGLang repository: + + ```bash + git clone https://github.com/kvcache-ai/sglang.git + cd sglang + pip install -e "python[all]" + ``` + + You can follow [SGLang integration steps](https://docs.sglang.io/get_started/install.html) + +2. **KT-Kernel installed** + + Please follow [kt-kernel](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md) + + After installation, verify the CLI is working: + + ```bash + kt version + ``` + +3. **CUDA toolkit** - CUDA 12.0+ recommended for FP8 support +4. **Hugging Face CLI** - For downloading models: + ```bash + pip install -U huggingface-hub + ``` + +## Step 1: Download Model Weights + +Download the official MiniMax-M2.1 weights. + +* huggingface: https://huggingface.co/MiniMaxAI/MiniMax-M2.1 + + ```bash + hf download MiniMaxAI/MiniMax-M2.1 --local-dir /path/to/minimax-m2.1 + ``` + +## Step 2: Launch Server with KT CLI + +The simplest way to start the MiniMax-M2.1 server is using the `kt` CLI: + +```bash +kt run m2.1 +``` + +The CLI will automatically detect your hardware configuration and apply optimal parameters for your system. + +### Advanced Options + +For custom configurations, you can specify additional parameters: + +```bash +# Use specific number of GPUs (tensor parallel) +kt run m2.1 --tensor-parallel-size 2 + +# Custom CPU threads and NUMA configuration +kt run m2.1 --cpu-threads 64 --numa-nodes 2 +``` + +### Dry Run + +To preview the command without executing: + +```bash +kt run m2.1 --dry-run +``` + +See [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines. + +### Key Parameters + +| Parameter | Description | +|-----------|-------------| +| `--kt-method FP8` | Enable FP8 inference mode for MiniMax-M2.1 native FP8 weights. | +| `--kt-cpuinfer` | Number of CPU inference threads. Set to physical CPU cores (not hyperthreads). 
| +| `--kt-threadpool-count` | Number of thread pools. Set to NUMA node count. | +| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. | +| `--chunked-prefill-size` | Maximum tokens per prefill batch. | +| `--max-total-tokens` | Maximum total tokens in KV cache. | +| `--kt-gpu-prefill-token-threshold` | Token threshold for layerwise prefill strategy. | + +## Step 3: Send Inference Requests + +Once the server is running (default: `http://localhost:30000`), you can interact with the model in several ways: + +### Option A: Interactive Chat with KT CLI + +The easiest way to chat with the model: + +```bash +kt chat +``` + +This opens an interactive terminal chat session. Type your messages and press Enter to send. Use `Ctrl+C` to exit. + +### Option B: OpenAI-Compatible API + +The server exposes an OpenAI-compatible API at `http://localhost:30000/v1`. + +**curl example (streaming):** + +```bash +curl http://localhost:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "MiniMax-M2.1", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": true + }' +``` + + +## Performance + +### Throughput (tokens/s) + +The following benchmarks were measured with single concurrency (Prefill tps / Decode tps): + +| GPU | CPU | PCIe | 2048 tokens | 8192 tokens | 32768 tokens | +|------------|-------------|-------------|-------------|-------------|--------------| +| 1 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 129 / 21.8 | 669 / 20.9 | 1385 / 18.5 | +| 2 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 139 / 23.6 | 1013 / 23.3 | 2269 / 21.6 | +| 1 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 408 / 32.1 | 1196 / 31.4 | 2540 / 27.6 | +| 2 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 414 / 35.9 | 1847 / 35.5 | 4007 / 33.1 | + +### Comparison with llama.cpp + +We benchmarked KT-Kernel + Sglang against llama.cpp to demonstrate the performance advantages of our CPU-GPU heterogeneous 
inference approach. + +- **Weight formats**: KT-Kernel uses native unquantized FP8 weights from MiniMax-M2, while llama.cpp only supports quantized weights, so we used Q8_0 quantization for the llama.cpp benchmarks. + +- **Test environment**: 2 x RTX 5090 (32 GB) with AMD EPYC 9355 CPUs, input tokens=32768, output tokens=512. We made our best effort to optimize llama.cpp performance, but we could not achieve optimal prefill and decode with a single command, so we used separate configurations for prefill and decode measurements. + +![Performance Comparison with llama.cpp](../assets/MiniMax-M2_comparison.png) + +As shown in the chart, KT-Kernel achieves up to **>4.5x prefill** and **30% faster decode** compared to llama.cpp on the same hardware. + +## Troubleshooting + +### OOM (Out of Memory) Issues + +Layerwise prefill requires extra VRAM (~3.6GB + incremental cost with prefill length). If you encounter OOM, adjust these parameters when launching the server: + +| Parameter | VRAM Impact | +|-----------|-------------| +| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage | +| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation | +| `--max-total-tokens` | Reduces KV cache VRAM usage | + +**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill. + +## Advanced Use Case: Running Claude Code with MiniMax-M2.1 Local Backend + +```bash +kt run m2.1 --tool-call-parser minimax-m2 --reasoning-parser minimax-append-think +``` + +With the above command, you can use [claude-code-router](https://github.com/musistudio/claude-code-router) to connect MiniMax-M2.1 as a local backend for [Claude Code](https://github.com/anthropics/claude-code). 
+ +## Additional Resources + +- [KT-Kernel Documentation](../../kt-kernel/README.md) +- [SGLang GitHub](https://github.com/sgl-project/sglang) +- [KT-Kernel Parameters Reference](../../kt-kernel/README.md#kt-kernel-parameters) \ No newline at end of file diff --git a/kt-kernel/README.md b/kt-kernel/README.md index 10ca854..3fc7770 100644 --- a/kt-kernel/README.md +++ b/kt-kernel/README.md @@ -38,6 +38,7 @@ High-performance kernel operations for KTransformers, featuring CPU-optimized Mo - ✅ **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights) - ✅ **AMD CPUs with BLIS**: Supported (for int8 prefill & decode) - ✅ **Kimi-K2 Native INT4 (RAWINT4)**: Supported on AVX512 CPUs (CPU-GPU shared INT4 weights) - [Guide](../doc/en/Kimi-K2-Thinking-Native.md) +- ✅ **FP8 weights (e.g., MiniMax-M2.1)**: Supported on AVX512 CPUs (CPU-GPU shared FP8 weights) - [Guide](../doc/en/MiniMax-M2.1-Tutorial.md) ## Features @@ -167,10 +168,57 @@ Simply run the install script - it will auto-detect your CPU and optimize for be ## Verification +After installation, verify that the CLI is working: + +```bash +kt version +``` + +Expected output: +``` +KTransformers CLI v0.x.x + + Python: 3.11.x + Platform: Linux 5.15.0-xxx-generic + CUDA: 12.x + kt-kernel: 0.x.x (amx) + sglang: 0.x.x +``` + +You can also verify the Python module directly: + ```bash python -c "from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')" ``` +## KT CLI Overview + +The `kt` command-line tool provides a unified interface for running and managing KTransformers models: + +| Command | Description | +|---------|-------------| +| `kt run ` | Start model inference server with auto-optimized parameters | +| `kt chat` | Interactive chat with a running model server | +| `kt model` | Manage models and storage paths | +| `kt doctor` | Diagnose environment issues and check system compatibility | +| `kt config` | Manage CLI configuration | +| `kt version` | Show version information | + 
+**Quick Start Example:** + +```bash +# Start a model server (auto-detects hardware and applies optimal settings) +kt run m2 + +# In another terminal, chat with the model +kt chat + +# Check system compatibility +kt doctor +``` + +Run `kt --help` for more options, or `kt --help` for command-specific help. + ## Integration with SGLang KT-Kernel can be used standalone via [Direct Python API](#direct-python-api-usage) or integrated with SGLang for production deployment. This section describes SGLang integration to enable CPU-GPU heterogeneous inference, where "hot" experts run on GPU and "cold" experts run on CPU for optimal resource utilization. @@ -361,13 +409,13 @@ python -m sglang.launch_server \ | Parameter | Description | Example Value | |-----------|-------------|---------------| -| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, or `LLAMAFILE` | +| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8` or `LLAMAFILE` | | `--kt-weight-path` | Path to quantized CPU weights | `/path/to/cpu-weights` | | `--kt-cpuinfer` | Number of CPU inference threads | `64` (adjust based on CPU cores) | | `--kt-threadpool-count` | Number of thread pools for parallel execution | `2` (typically 1-4) | | `--kt-num-gpu-experts` | Number of experts to keep on GPU | `32` (remaining experts go to CPU) | | `--kt-max-deferred-experts-per-token` | Number of experts per token to defer for pipelined execution | `2` (0 to disable, 1-4 recommended) | -| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (RAWINT4 only) | ~`400` | +| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (FP8 and RAWINT4 only) | ~`1024` | **Parameter Guidelines:** @@ -375,6 +423,7 @@ python -m sglang.launch_server \ - `AMXINT4`: Best performance on AMX CPUs with INT4 quantized weights (May cause huge accuracy drop for some models, e.g., Qwen3-30B-A3B) - `AMXINT8`: Higher accuracy with INT8 
quantized weights on AMX CPUs - `RAWINT4`: Native INT4 weights shared by CPU and GPU (AMX backend only, currently supports Kimi-K2-Thinking model). See [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details. + - `FP8`: FP8 weights shared by CPU and GPU - `LLAMAFILE`: GGUF-based backend - **`kt-cpuinfer`**: Set to the number of **physical CPU cores** (not hyperthreads). @@ -400,10 +449,10 @@ python -m sglang.launch_server \ - `1-4`: Deferred execution (recommended range; good latency/quality balance, requires tuning) - `5-7`: Highest latency reduction but may introduce noticeable accuracy loss; use with care -- **`kt-gpu-prefill-token-threshold`** (RAWINT4 only): Controls prefill strategy for native INT4 inference: +- **`kt-gpu-prefill-token-threshold`** (FP8 and RAWINT4 only): Controls prefill strategy for native FP8 and INT4 inference: - **≤ threshold**: Uses hybrid CPU+GPU prefill. No extra VRAM needed, but performance degrades slowly as token count increases. - - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires ~9GB+ extra VRAM. - - Only applicable when `--kt-method RAWINT4` is used. Currently supports Kimi-K2-Thinking model only. + - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires one MoE layer extra VRAM (e.g., ~9GB+ for Kimi-K2-Thinking and ~3.6GB for MiniMax-M2.1). + - Only applicable when `--kt-method RAWINT4` or `--kt-method FP8` is used. ## Direct Python API Usage diff --git a/kt-kernel/bench/bench_fp8_moe.py b/kt-kernel/bench/bench_fp8_moe.py new file mode 100644 index 0000000..a5108f5 --- /dev/null +++ b/kt-kernel/bench/bench_fp8_moe.py @@ -0,0 +1,286 @@ +""" +Performance benchmark for FP8 MoE kernel (AVX implementation). 
+ +This benchmark measures the performance of the FP8 MoE operator with: +- FP8 (E4M3) weights with 128x128 block-wise scaling +- BF16 activations +- AVX-512 DPBF16 compute path +""" + +import os +import sys +import time +import json +import subprocess +import platform + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) + +import torch +import kt_kernel_ext +from tqdm import tqdm + +# Test parameters +expert_num = 256 +hidden_size = 7168 +intermediate_size = 2048 +num_experts_per_tok = 8 +fp8_group_size = 128 +max_len = 25600 + +layer_num = 2 +qlen = 1024 +warm_up_iter = 10 +test_iter = 30 +CPUINFER_PARAM = 80 + +CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM) + +# Result file path +script_path = os.path.abspath(__file__) +script_dir = os.path.dirname(script_path) +json_path = os.path.join(script_dir, "bench_results.jsonl") + + +def get_git_commit(): + """Get current git commit info""" + result = {} + try: + commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() + commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip() + result["commit"] = commit + result["commit_message"] = commit_msg + dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip() + result["dirty"] = bool(dirty_output) + if dirty_output: + result["dirty_files"] = dirty_output.splitlines() + except Exception as e: + result["commit"] = None + result["error"] = str(e) + return result + + +def get_system_info(): + """Get system information""" + info = {} + uname = platform.uname() + info["system_name"] = uname.system + info["node_name"] = uname.node + + cpu_model = None + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_model = line.split(":", 1)[1].strip() + break + except Exception: + pass + info["cpu_model"] = cpu_model + info["cpu_core_count"] = os.cpu_count() + return info + + 
+def record_results(result, filename=json_path): + """Append result to JSON file""" + with open(filename, "a") as f: + f.write(json.dumps(result) + "\n") + + +def generate_fp8_weights_direct(shape: tuple, group_size: int = 128): + """ + Directly generate random FP8 weights and e8m0 format scale_inv. + + Args: + shape: (expert_num, n, k) - weight tensor shape + group_size: block size for scaling (128x128 blocks) + + Returns: + fp8_weights: uint8 tensor with random FP8 E4M3 values + scale_inv: fp32 tensor with e8m0 format (powers of 2) + """ + e, n, k = shape + n_blocks = n // group_size + k_blocks = k // group_size + + # Directly generate random FP8 weights as uint8 + # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa + # Valid range for normal numbers: exp 1-14 (0 is subnormal, 15 is special) + fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous() + + # Generate e8m0 format scale_inv (powers of 2) + # e8m0: 8-bit exponent only, no mantissa, bias = 127 + # Generate random exponents in a reasonable range (e.g., -8 to 8) + exponents = torch.randint(-8, 9, (e, n_blocks, k_blocks), dtype=torch.int32, device="cuda").to("cpu").contiguous() + scale_inv = (2.0 ** exponents.float()).to(torch.float32).contiguous() + + return fp8_weights, scale_inv + + +def bench_fp8_moe(): + """Benchmark FP8 MoE performance""" + with torch.inference_mode(): + print("=" * 70) + print("FP8 MoE Kernel Performance Benchmark") + print("=" * 70) + + # Generate FP8 weights directly (no quantization from fp32) + print("\nGenerating FP8 weights directly...") + torch.manual_seed(42) + gate_fp8, gate_scales = generate_fp8_weights_direct( + (expert_num, intermediate_size, hidden_size), fp8_group_size + ) + up_fp8, up_scales = generate_fp8_weights_direct((expert_num, intermediate_size, hidden_size), fp8_group_size) + down_fp8, down_scales = generate_fp8_weights_direct( + (expert_num, hidden_size, intermediate_size), fp8_group_size + ) + + 
physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous() + + # Build MoE layers + print("Building FP8 MoE layers...") + moes = [] + for _ in tqdm(range(layer_num), desc="Initializing MOEs"): + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) + config.max_len = max_len + config.quant_config.bits = 8 + config.quant_config.group_size = fp8_group_size + config.quant_config.zero_point = False + + config.gate_proj = gate_fp8.data_ptr() + config.up_proj = up_fp8.data_ptr() + config.down_proj = down_fp8.data_ptr() + config.gate_scale = gate_scales.data_ptr() + config.up_scale = up_scales.data_ptr() + config.down_scale = down_scales.data_ptr() + config.pool = CPUInfer.backend_ + + moe = kt_kernel_ext.moe.AMXFP8_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + moes.append(moe) + + # Generate input data + print("Generating input data...") + gen_iter = 1000 + expert_ids = ( + torch.rand(gen_iter * qlen, expert_num, device="cpu") + .argsort(dim=-1)[:, :num_experts_per_tok] + .reshape(gen_iter, qlen * num_experts_per_tok) + .contiguous() + ) + weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous() + input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous() + output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous() + qlen_tensor = torch.tensor([qlen], dtype=torch.int32) + + # Warmup + print(f"Warming up ({warm_up_iter} iterations)...") + for i in tqdm(range(warm_up_iter), desc="Warm-up"): + CPUInfer.submit( + moes[i % layer_num].forward_task( + qlen_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor[i % layer_num].data_ptr(), + output_tensor[i % layer_num].data_ptr(), + False, + ) + ) + 
CPUInfer.sync() + + # Benchmark + print(f"Running benchmark ({test_iter} iterations)...") + start = time.perf_counter() + for i in tqdm(range(test_iter), desc="Testing"): + CPUInfer.submit( + moes[i % layer_num].forward_task( + qlen_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor[i % layer_num].data_ptr(), + output_tensor[i % layer_num].data_ptr(), + False, + ) + ) + CPUInfer.sync() + end = time.perf_counter() + total_time = end - start + + # Calculate metrics + time_per_iter_us = total_time / test_iter * 1e6 + + # FLOPS calculation: + # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate) + # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element) + # For vector-matrix multiply (qlen=1): 2 * n * k per matrix + flops_per_expert = ( + 2 * intermediate_size * hidden_size # gate + + 2 * intermediate_size * hidden_size # up + + 2 * hidden_size * intermediate_size # down + ) + total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter + tflops = total_flops / total_time / 1e12 + + # Bandwidth calculation (FP8 = 1 byte per element) + bytes_per_elem = 1.0 + # Weight memory: gate + up + down per expert + bandwidth = ( + hidden_size + * intermediate_size + * 3 + * num_experts_per_tok + * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen)) + * bytes_per_elem + * test_iter + / total_time + / 1e9 + ) # 单位:GB/s + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results") + print("=" * 70) + print(f"Quant mode: FP8 (E4M3) with {fp8_group_size}x{fp8_group_size} block scaling") + print(f"Total time: {total_time:.4f} s") + print(f"Iterations: {test_iter}") + print(f"Time per iteration: {time_per_iter_us:.2f} us") + print(f"Bandwidth: {bandwidth:.2f} GB/s") + print(f"TFLOPS: {tflops:.4f}") + print("") + + # Record results + result = { + "test_name": 
os.path.basename(__file__), + "quant_mode": "fp8_e4m3", + "total_time_seconds": total_time, + "iterations": test_iter, + "time_per_iteration_us": time_per_iter_us, + "bandwidth_GBs": bandwidth, + "flops_TFLOPS": tflops, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "test_parameters": { + "expert_num": expert_num, + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "num_experts_per_tok": num_experts_per_tok, + "fp8_group_size": fp8_group_size, + "layer_num": layer_num, + "qlen": qlen, + "warm_up_iter": warm_up_iter, + "test_iter": test_iter, + "CPUInfer_parameter": CPUINFER_PARAM, + }, + } + result.update(get_git_commit()) + result.update(get_system_info()) + record_results(result) + + return tflops, bandwidth + + +if __name__ == "__main__": + bench_fp8_moe() diff --git a/kt-kernel/bench/bench_fp8_write_buffer.py b/kt-kernel/bench/bench_fp8_write_buffer.py new file mode 100644 index 0000000..f14a4b7 --- /dev/null +++ b/kt-kernel/bench/bench_fp8_write_buffer.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +Benchmark write_weight_scale_to_buffer for AMX_FP8_MOE_TP (FP8 weights + float32 scales). + +Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios. 
+""" +import json +import os +import platform +import subprocess +import sys +import time + +from tqdm import tqdm + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) + +from kt_kernel import kt_kernel_ext +from kt_kernel_ext.moe import AMXFP8_MOE +import torch + +# Benchmark parameters +expert_num = 256 +num_experts_per_tok = 8 +gpu_tp_count = 2 + +warm_up_iter = 3 +test_iter = 7 + +gpu_experts_num = expert_num + +hidden_size = 7168 +intermediate_size = 2048 +group_size = 128 # FP8 uses 128x128 block-wise scales +max_len = 1 + +physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous() +CPUInfer = kt_kernel_ext.CPUInfer(80) + + +def get_git_commit(): + result = {} + try: + commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() + commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip() + result["commit"] = commit + result["commit_message"] = commit_msg + dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip() + result["dirty"] = bool(dirty_output) + if dirty_output: + result["dirty_files"] = dirty_output.splitlines() + except Exception as e: + result["error"] = str(e) + return result + + +def get_system_info(): + info = {} + info["system_name"] = platform.uname().system + info["node_name"] = platform.uname().node + info["cpu_core_count"] = os.cpu_count() + if os.path.exists("/proc/cpuinfo"): + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + info["cpu_model"] = line.split(":", 1)[1].strip() + break + if os.path.exists("/proc/meminfo"): + with open("/proc/meminfo", "r") as f: + for line in f: + if "MemTotal" in line: + mem_kb = float(line.split(":", 1)[1].split()[0]) + info["memory_size_GB"] = round(mem_kb / (1024 * 1024), 2) + break + return info + + +script_path = os.path.abspath(__file__) +script_dir = os.path.dirname(script_path) +script_name = 
os.path.splitext(os.path.basename(script_path))[0] +json_path = os.path.join(script_dir, script_name + ".jsonl") + + +def record_results(result, filename=json_path): + with open(filename, "a") as f: + f.write(json.dumps(result) + "\n") + + +def allocate_weights(): + per_mat_weight_bytes = hidden_size * intermediate_size + n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size + n_blocks_k = (hidden_size + group_size - 1) // group_size + per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k + per_mat_scale_elems_down = n_blocks_k * n_blocks_n_gate_up + + gate_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + up_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + down_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + gate_scale = ( + torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + up_scale = ( + torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + down_scale = ( + torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + + return ( + gate_q, + up_q, + down_q, + gate_scale, + up_scale, + down_scale, + per_mat_weight_bytes, + per_mat_scale_elems_gate_up, + per_mat_scale_elems_down, + ) + + +def build_moe(layer_idx=0): + """Build a single MOE instance with the given layer_idx.""" + ( + gate_q, + up_q, + down_q, + gate_scale, + up_scale, + down_scale, + per_mat_weight_bytes, + per_mat_scale_elems_gate_up, + per_mat_scale_elems_down, + ) = allocate_weights() + + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) + config.max_len = max_len + 
config.layer_idx = layer_idx + config.quant_config.bits = 8 + config.quant_config.group_size = group_size + config.quant_config.zero_point = False + config.pool = CPUInfer.backend_ + config.gate_proj = gate_q.data_ptr() + config.up_proj = up_q.data_ptr() + config.down_proj = down_q.data_ptr() + config.gate_scale = gate_scale.data_ptr() + config.up_scale = up_scale.data_ptr() + config.down_scale = down_scale.data_ptr() + + moe = AMXFP8_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + + keep_tensors = { + "gate_q": gate_q, + "up_q": up_q, + "down_q": down_q, + "gate_scale": gate_scale, + "up_scale": up_scale, + "down_scale": down_scale, + } + + buffer_shapes = { + "per_mat_weight_bytes": per_mat_weight_bytes, + "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up, + "per_mat_scale_elems_down": per_mat_scale_elems_down, + } + + return moe, buffer_shapes, keep_tensors + + +def allocate_buffers(buffer_shapes): + """Allocate shared output buffers for single expert.""" + per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"] + per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"] + per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"] + + weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count + scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count + scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down // gpu_tp_count + + # Each buffer stores data for a single expert + w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w13_scale_bufs = [ + torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count) + ] + w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for 
_ in range(gpu_tp_count)] + + buffer_ptrs = { + "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs], + "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs], + "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs], + "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs], + } + + keep_tensors = { + "w13_weight_bufs": w13_weight_bufs, + "w13_scale_bufs": w13_scale_bufs, + "w2_weight_bufs": w2_weight_bufs, + "w2_scale_bufs": w2_scale_bufs, + } + + return buffer_ptrs, keep_tensors + + +def bench_write_buffer(): + # Build two MOE instances with different layer_idx + moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0) + moe_1, _, keep_tensors_1 = build_moe(layer_idx=1) + moes = [moe_0, moe_1] + + # Allocate shared buffers + buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes) + + total_weights = hidden_size * intermediate_size * expert_num * 3 + total_scale_bytes = ( + (buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"]) * expert_num * 4 + ) + bytes_per_call = total_weights + total_scale_bytes + + # Warm-up: alternate between two MOEs + for _ in tqdm(range(warm_up_iter), desc="Warm-up"): + for moe_idx, moe in enumerate(moes): + for expert_id in range(gpu_experts_num): + CPUInfer.submit( + moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs) + ) + CPUInfer.sync() + + total_time = 0 + for iter_idx in tqdm(range(test_iter), desc="Testing"): + start = time.perf_counter() + # Alternate between two MOEs + for moe_idx, moe in enumerate(moes): + for expert_id in range(gpu_experts_num): + CPUInfer.submit( + moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs) + ) + CPUInfer.sync() + end = time.perf_counter() + iter_time = end - start + total_time += iter_time + print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms") + time.sleep(0.3) + + # bytes_per_call is for one MOE, we have 2 MOEs + 
bytes_per_iter = bytes_per_call * 2 + time_per_iter_ms = total_time / test_iter * 1000 + bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9 + + print(f"\n{'='*60}") + print("FP8 write_weight_scale_to_buffer benchmark (2 MOEs alternating)") + print(f"{'='*60}") + print(f"Time per iteration: {time_per_iter_ms:.2f} ms") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2") + print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us") + + result = { + "op": "write_weight_scale_to_buffer_fp8", + "time_per_iteration_ms": time_per_iter_ms, + "bandwidth_GBs": bandwidth_gbs, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "test_parameters": { + "expert_num": expert_num, + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "group_size": group_size, + "gpu_tp_count": gpu_tp_count, + "bytes_per_iter": bytes_per_iter, + "num_moes": 2, + }, + } + result.update(get_git_commit()) + result.update(get_system_info()) + record_results(result) + + +if __name__ == "__main__": + bench_write_buffer() diff --git a/kt-kernel/bench/bench_k2_write_buffer.py b/kt-kernel/bench/bench_k2_write_buffer.py index 30e042c..cece01e 100644 --- a/kt-kernel/bench/bench_k2_write_buffer.py +++ b/kt-kernel/bench/bench_k2_write_buffer.py @@ -2,6 +2,8 @@ # coding=utf-8 """ Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales). + +Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios. 
""" import json import os @@ -17,7 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) from kt_kernel import kt_kernel_ext import torch -# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py) +# Benchmark parameters expert_num = 384 num_experts_per_tok = expert_num gpu_tp_count = 4 @@ -33,7 +35,7 @@ group_size = 32 max_len = 1 physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous() -CPUInfer = kt_kernel_ext.CPUInfer(96) +CPUInfer = kt_kernel_ext.CPUInfer(80) def get_git_commit(): @@ -140,7 +142,8 @@ def allocate_weights(): ) -def build_moe(): +def build_moe(layer_idx=0): + """Build a single MOE instance with the given layer_idx.""" ( gate_q, up_q, @@ -154,6 +157,7 @@ def build_moe(): config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) config.max_len = max_len + config.layer_idx = layer_idx config.quant_config.bits = 4 config.quant_config.group_size = group_size config.quant_config.zero_point = False @@ -170,16 +174,36 @@ def build_moe(): CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) CPUInfer.sync() - # Buffer sizing per TP + keep_tensors = { + "gate_q": gate_q, + "up_q": up_q, + "down_q": down_q, + "gate_scale": gate_scale, + "up_scale": up_scale, + "down_scale": down_scale, + } + + buffer_shapes = { + "per_mat_weight_bytes": per_mat_weight_bytes, + "per_mat_scale_elems": per_mat_scale_elems, + } + + return moe, buffer_shapes, keep_tensors + + +def allocate_buffers(buffer_shapes): + """Allocate shared output buffers for single expert.""" + per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"] + per_mat_scale_elems = buffer_shapes["per_mat_scale_elems"] + weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count - total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp - 
total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp - w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] - w13_scale_bufs = [torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)] - w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] - w2_scale_bufs = [torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)] + # Each buffer stores data for a single expert + w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w13_scale_bufs = [torch.empty(2 * scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)] + w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)] buffer_ptrs = { "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs], @@ -188,97 +212,89 @@ def build_moe(): "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs], } - buffer_shapes = { - "per_mat_weight_bytes": per_mat_weight_bytes, - "per_mat_scale_elems": per_mat_scale_elems, - "weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp, - "scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp, - "total_weight_bytes_per_tp": total_weight_bytes_per_tp, - "total_scale_elems_per_tp": total_scale_elems_per_tp, - } - keep_tensors = { - "gate_q": gate_q, - "up_q": up_q, - "down_q": down_q, - "gate_scale": gate_scale, - "up_scale": up_scale, - "down_scale": down_scale, "w13_weight_bufs": w13_weight_bufs, "w13_scale_bufs": w13_scale_bufs, "w2_weight_bufs": w2_weight_bufs, "w2_scale_bufs": w2_scale_bufs, } - return moe, buffer_ptrs, buffer_shapes, keep_tensors + return buffer_ptrs, keep_tensors def bench_write_buffer(): - moe, 
buffer_ptrs, buffer_shapes, keep_tensors = build_moe() + # Build two MOE instances with different layer_idx + moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0) + moe_1, _, keep_tensors_1 = build_moe(layer_idx=1) + moes = [moe_0, moe_1] + + # Allocate shared buffers + buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes) total_weights = hidden_size * intermediate_size * expert_num * 3 - # Throughput accounting consistent with examples/test_k2_write_buffer.py - bytes_per_call = total_weights // group_size + total_weights // 2 + # Throughput accounting: scale bytes (bf16) + weight bytes (int4 packed) + bytes_per_call = total_weights // group_size * 2 + total_weights // 2 - # Warm-up + # Warm-up: alternate between two MOEs for _ in tqdm(range(warm_up_iter), desc="Warm-up"): - CPUInfer.submit( - moe.write_weight_scale_to_buffer_task( - gpu_tp_count=gpu_tp_count, - gpu_experts_num=gpu_experts_num, - **buffer_ptrs, - ) - ) - CPUInfer.sync() + for moe_idx, moe in enumerate(moes): + for expert_id in range(gpu_experts_num): + CPUInfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + expert_id=expert_id, + **buffer_ptrs, + ) + ) + CPUInfer.sync() total_time = 0 - for _ in tqdm(range(test_iter), desc="Testing"): + for iter_idx in tqdm(range(test_iter), desc="Testing"): start = time.perf_counter() - CPUInfer.submit( - moe.write_weight_scale_to_buffer_task( - gpu_tp_count=gpu_tp_count, - gpu_experts_num=gpu_experts_num, - **buffer_ptrs, - ) - ) - CPUInfer.sync() + # Alternate between two MOEs + for moe_idx, moe in enumerate(moes): + for expert_id in range(gpu_experts_num): + CPUInfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + expert_id=expert_id, + **buffer_ptrs, + ) + ) + CPUInfer.sync() end = time.perf_counter() - total_time += end - start - time.sleep(0.6) - print(end - start) + iter_time = end - start + total_time += iter_time + print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms") + 
time.sleep(0.3) - time_per_iter_us = total_time / test_iter * 1e6 - bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9 + # bytes_per_call is for one MOE, we have 2 MOEs + bytes_per_iter = bytes_per_call * 2 + time_per_iter_ms = total_time / test_iter * 1000 + bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9 - print("write_weight_scale_to_buffer benchmark") - print("Time(s): ", total_time) - print("Iteration: ", test_iter) - print("Time(us) per iteration: ", time_per_iter_us) - print("Bandwidth: ", bandwidth_gbs, "GB/s") - print("") + print(f"\n{'='*60}") + print("K2 write_weight_scale_to_buffer benchmark (2 MOEs alternating)") + print(f"{'='*60}") + print(f"Time per iteration: {time_per_iter_ms:.2f} ms") + print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s") + print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2") + print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us") result = { - "op": "write_weight_scale_to_buffer", - "total_time_seconds": total_time, - "iterations": test_iter, - "time_per_iteration_us": time_per_iter_us, + "op": "write_weight_scale_to_buffer_k2", + "time_per_iteration_ms": time_per_iter_ms, "bandwidth_GBs": bandwidth_gbs, - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "test_parameters": { "expert_num": expert_num, "hidden_size": hidden_size, "intermediate_size": intermediate_size, "group_size": group_size, - "max_len": max_len, - "num_experts_per_tok": num_experts_per_tok, "gpu_tp_count": gpu_tp_count, - "gpu_experts_num": gpu_experts_num, - "warm_up_iter": warm_up_iter, - "test_iter": test_iter, - "bytes_per_call": bytes_per_call, + "bytes_per_iter": bytes_per_iter, + "num_moes": 2, }, - "buffer_shapes": buffer_shapes, - "keep_tensors_alive": list(keep_tensors.keys()), } result.update(get_git_commit()) result.update(get_system_info()) diff --git a/kt-kernel/examples/test_fp8_moe.py b/kt-kernel/examples/test_fp8_moe.py new file mode 
100644 index 0000000..0b7f2e0 --- /dev/null +++ b/kt-kernel/examples/test_fp8_moe.py @@ -0,0 +1,457 @@ +""" +Test script for GemmKernel224FP8 (FP8 MoE) kernel validation. + +This script: +1. Generates random BF16 weights +2. Quantizes them to FP8 format with 128x128 block-wise scales +3. Runs the FP8 MoE kernel +4. Compares results with PyTorch reference using dequantized BF16 weights + +FP8 format notes: +- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k] +- Scale: FP32, shape [expert_num, n // group_size, k // group_size], group_size=128 +""" + +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__) + "/../build") + +import torch +import kt_kernel + +torch.manual_seed(42) + +# Model config +hidden_size = 3072 +intermediate_size = 1536 +max_len = 25600 + +expert_num = 16 +num_experts_per_tok = 8 + +qlen = 100 +layer_num = 1 +CPUInfer = kt_kernel_ext.CPUInfer(40) +validation_iter = 1 +fp8_group_size = 128 # FP8 uses 128x128 block quantization +debug_print_count = 16 + +physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() + + +def act_fn(x): + """SiLU activation function""" + return x / (1.0 + torch.exp(-x)) + + +def mlp_torch(input, gate_proj, up_proj, down_proj): + """Reference MLP computation in PyTorch""" + gate_buf = torch.mm(input, gate_proj.t()) + up_buf = torch.mm(input, up_proj.t()) + intermediate = act_fn(gate_buf) * up_buf + ret = torch.mm(intermediate, down_proj.t()) + return ret + + +def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): + """Reference MoE computation in PyTorch""" + cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) + cnts.scatter_(1, expert_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = expert_ids.view(-1).argsort() + sorted_tokens = input[idxs // expert_ids.shape[1]] + + outputs = [] + start_idx = 0 + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + 
# FP8 E4M3 constants
FP8_E4M3_MAX = 448.0  # Largest finite magnitude in OCP FP8 E4M3 (code S.1111.110)


def fp8_e4m3_to_float(fp8_val: int) -> float:
    """
    Convert an FP8 E4M3 byte (0..255) to a Python float.

    FP8 E4M3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    E4M3 has no infinities and only a single NaN pattern per sign
    (S.1111.111, i.e. 0x7F / 0xFF); codes with exp == 15 and mant < 7 are
    ordinary normal values in the 256..448 range.
    """
    sign = (fp8_val >> 7) & 1
    exp = (fp8_val >> 3) & 0xF
    mant = fp8_val & 0x7

    if exp == 0:
        # Subnormal or zero: value = (-1)^sign * 2^-6 * (mant / 8)
        if mant == 0:
            return -0.0 if sign else 0.0
        return ((-1) ** sign) * (2**-6) * (mant / 8.0)
    elif exp == 15 and mant == 7:
        # The sole NaN encoding in E4M3 (there is no Inf).
        return float("nan")
    else:
        # Normal: value = (-1)^sign * 2^(exp-7) * (1 + mant / 8)
        return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)


def float_to_fp8_e4m3(val: float) -> int:
    """
    Convert a float to the nearest FP8 E4M3 byte.

    Magnitudes are clamped to FP8_E4M3_MAX (448); NaN maps to 0x7F.
    The NaN pattern S.1111.111 is never produced for finite inputs.
    """
    import math

    if val != val:  # NaN
        return 0x7F  # canonical NaN representation

    sign = 1 if val < 0 else 0
    val = abs(val)

    if val == 0:
        return sign << 7

    # Clamp to the largest finite E4M3 magnitude.
    val = min(val, FP8_E4M3_MAX)

    if val < 2**-6:
        # Subnormal range: quantize in units of the subnormal ULP 2^-9.
        mant = int(round(val / (2**-9)))
        if mant <= 7:
            return (sign << 7) | mant
        # Rounded up to 8 * 2^-9 == 2^-6: that is the smallest normal.
        return (sign << 7) | (1 << 3)

    # Normal range: exponent may legitimately reach 15 (values 256..448).
    exp = int(math.floor(math.log2(val))) + 7
    exp = max(1, min(exp, 15))

    mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))
    if mant > 7:
        # Rounding carried into the next binade (e.g. 1.99 -> 2.0).
        mant = 0
        exp += 1
    if exp > 15 or (exp == 15 and mant > 6):
        # Never emit the NaN pattern; saturate at 448 (exp=15, mant=6).
        exp, mant = 15, 6

    return (sign << 7) | (exp << 3) | mant


def quantize_to_fp8_blockwise(weights: torch.Tensor, group_size: int = 128):
    """
    Quantize BF16/FP32 weights to FP8 E4M3 with block-wise scaling.

    Args:
        weights: [expert_num, n, k] tensor in BF16/FP32
        group_size: Square block edge for quantization (default 128 for DeepSeek)

    Returns:
        fp8_weights: [expert_num, n, k] uint8 tensor of E4M3 codes
        scales: [expert_num, n // group_size, k // group_size] FP32 tensor,
                stored as scale_inv (dequantization is fp8 * scale)
    """
    weights_f32 = weights.to(torch.float32)
    e, n, k = weights_f32.shape

    assert n % group_size == 0, f"n ({n}) must be divisible by group_size ({group_size})"
    assert k % group_size == 0, f"k ({k}) must be divisible by group_size ({group_size})"

    n_blocks = n // group_size
    k_blocks = k // group_size

    # [e, n_blocks, k_blocks, group_size, group_size]: one trailing tile per scale block.
    reshaped = weights_f32.view(e, n_blocks, group_size, k_blocks, group_size)
    reshaped = reshaped.permute(0, 1, 3, 2, 4)

    # Per-block scale chosen so the block max maps onto the E4M3 max (448).
    max_abs = reshaped.abs().amax(dim=(-2, -1), keepdim=True)
    max_abs = torch.clamp(max_abs, min=1e-12)
    scales = (max_abs / FP8_E4M3_MAX).squeeze(-1).squeeze(-1)  # [e, n_blocks, k_blocks]

    # Quantize: q = round_to_fp8(val / scale), clamped to the representable range.
    scaled = reshaped / (scales.unsqueeze(-1).unsqueeze(-1) + 1e-12)
    scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)

    fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)

    sign_mask = (scaled < 0).to(torch.uint8) << 7
    abs_scaled = scaled.abs()

    # Subnormal codes: 0 < |x| < 2^-6, mantissa counts multiples of 2^-9.
    subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)
    subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)

    # Normal codes. The exponent may reach 15 (values 256..448); clamping it to 14
    # would mis-encode every block maximum (448 would come out as 240).
    normal_mask = abs_scaled >= 2**-6
    log2_val = torch.log2(abs_scaled.clamp(min=2**-9))
    exp = (log2_val.floor() + 7).clamp(1, 15).to(torch.int32)
    mant_raw = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().to(torch.int32)
    # Rounding can push the mantissa past 7: carry into the next binade.
    carry = mant_raw > 7
    exp = torch.where(carry, exp + 1, exp)
    mant_raw = torch.where(carry, torch.zeros_like(mant_raw), mant_raw)
    # exp == 15 with mant == 7 would be the NaN code: saturate at 448 instead.
    mant_raw = torch.where((exp == 15) & (mant_raw > 6), torch.full_like(mant_raw, 6), mant_raw)
    mant = mant_raw.clamp(0, 7).to(torch.uint8)

    fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)
    fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)

    # Restore the original [e, n, k] layout.
    fp8_q = fp8_q.permute(0, 1, 3, 2, 4).reshape(e, n, k)
    scales_fp32 = scales.to(torch.float32).contiguous()

    return fp8_q.contiguous(), scales_fp32


def dequantize_fp8_blockwise(fp8_weights: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
    """
    Dequantize FP8 weights back to BF16 for reference computation.

    Args:
        fp8_weights: [expert_num, n, k] uint8 tensor of E4M3 codes
        scales: [expert_num, n // group_size, k // group_size] float tensor (scale_inv)
        group_size: Block size

    Returns:
        dequantized: [expert_num, n, k] BF16 tensor
    """
    e, n, k = fp8_weights.shape
    n_blocks = n // group_size
    k_blocks = k // group_size

    # Decode all 256 possible codes once, then gather through the lookup table.
    fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)
    fp8_float = fp8_lut[fp8_weights.to(torch.int64)]

    # Tile into per-scale blocks, apply scales, and restore the original layout.
    fp8_reshaped = fp8_float.view(e, n_blocks, group_size, k_blocks, group_size)
    fp8_reshaped = fp8_reshaped.permute(0, 1, 3, 2, 4)  # [e, n_blocks, k_blocks, g, g]
    scales_f32 = scales.to(torch.float32).unsqueeze(-1).unsqueeze(-1)
    dequantized = fp8_reshaped * scales_f32
    dequantized = dequantized.permute(0, 1, 3, 2, 4).reshape(e, n, k)

    return dequantized.to(torch.bfloat16).contiguous()
def build_random_fp8_weights():
    """
    Generate random BF16 weights and quantize to FP8.

    Returns:
        dict with fp8 weights, scales, and original bf16 for reference
    """
    torch.manual_seed(42)

    # Small magnitudes (sigma = 0.01) keep the BF16 reference MoE numerically tame.
    # The three randn calls stay in gate/up/down order so the fixed seed
    # reproduces the same tensors on every run.
    def _random_bf16(rows, cols):
        return (torch.randn((expert_num, rows, cols), dtype=torch.float32) / 100.0).to(torch.bfloat16)

    gate_proj = _random_bf16(intermediate_size, hidden_size)
    up_proj = _random_bf16(intermediate_size, hidden_size)
    down_proj = _random_bf16(hidden_size, intermediate_size)

    # Quantize each projection to FP8 with block-wise scales.
    gate_fp8, gate_scales = quantize_to_fp8_blockwise(gate_proj, fp8_group_size)
    up_fp8, up_scales = quantize_to_fp8_blockwise(up_proj, fp8_group_size)
    down_fp8, down_scales = quantize_to_fp8_blockwise(down_proj, fp8_group_size)

    # Dequantize for the torch reference path, so kernel vs. reference differ only
    # by the kernel's arithmetic, not by the quantization loss itself.
    gate_deq = dequantize_fp8_blockwise(gate_fp8, gate_scales, fp8_group_size)
    up_deq = dequantize_fp8_blockwise(up_fp8, up_scales, fp8_group_size)
    down_deq = dequantize_fp8_blockwise(down_fp8, down_scales, fp8_group_size)

    print(f"FP8 weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}")
    print(f"Scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}")

    # Debug dump of expert 0 so kernel-side weight loading can be cross-checked byte by byte.
    print("\n=== DEBUG: FP8 Weight and Scale Info (Expert 0) ===")

    def _dump_expert0(name, fp8_w, scales, lead=""):
        # Hex-dump the first 8x8 FP8 tile and the 4x4 scale corner of expert 0.
        print(f"{lead}{name}_fp8[0] first 8x8 block:")
        for i in range(8):
            print(f" row {i}: {fp8_w[0, i, :8].numpy().tobytes().hex(' ')}")
        print(f"{name}_fp8[0] stats: min={fp8_w[0].min()}, max={fp8_w[0].max()}")
        print(f"{name}_scales[0] first 4x4 block:\n{scales[0, :4, :4]}")
        print(f"{name}_scales[0] stats: min={scales[0].min()}, max={scales[0].max()}")

    _dump_expert0("gate", gate_fp8, gate_scales)
    _dump_expert0("up", up_fp8, up_scales, lead="\n")
    _dump_expert0("down", down_fp8, down_scales, lead="\n")

    return {
        "gate_fp8": gate_fp8.contiguous(),
        "up_fp8": up_fp8.contiguous(),
        "down_fp8": down_fp8.contiguous(),
        "gate_scales": gate_scales.contiguous(),
        "up_scales": up_scales.contiguous(),
        "down_scales": down_scales.contiguous(),
        "gate_deq": gate_deq.contiguous(),
        "up_deq": up_deq.contiguous(),
        "down_deq": down_deq.contiguous(),
    }


def build_moes_from_fp8_data(fp8_data: dict):
    """
    Build FP8 MoE modules from quantized data.

    One AMXFP8_MOE is built per layer; all layers share the same weight buffers,
    which must stay alive in `fp8_data` because the config stores raw pointers.
    """
    moes = []
    with torch.inference_mode(mode=True):
        for _ in range(layer_num):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = fp8_group_size
            config.quant_config.zero_point = False

            # Raw pointers into the FP8 weight tensors.
            config.gate_proj = fp8_data["gate_fp8"].data_ptr()
            config.up_proj = fp8_data["up_fp8"].data_ptr()
            config.down_proj = fp8_data["down_fp8"].data_ptr()

            # Raw pointers into the float32 block scales.
            config.gate_scale = fp8_data["gate_scales"].data_ptr()
            config.up_scale = fp8_data["up_scales"].data_ptr()
            config.down_scale = fp8_data["down_scales"].data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)
    return moes
def run_fp8_moe_test():
    """
    Run FP8 MoE validation test.

    Compares the AMX FP8 kernel output against a torch reference that uses the
    dequantized weights, reporting the relative L1 difference per iteration.
    """
    print("\n" + "=" * 70)
    print("FP8 MoE Kernel Validation Test")
    print("=" * 70)

    print("\nGenerating and quantizing weights...")
    fp8_data = build_random_fp8_weights()

    print("\nBuilding FP8 MoE modules...")
    moes = build_moes_from_fp8_data(fp8_data)

    # Reference path uses the dequantized weights so quantization loss cancels out.
    gate_deq = fp8_data["gate_deq"]
    up_deq = fp8_data["up_deq"]
    down_deq = fp8_data["down_deq"]

    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            # Fresh deterministic routing and activations per iteration.
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()

            # Round-robin across layers so every constructed MoE gets exercised.
            moe = moes[i % layer_num]
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

            assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
            assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."

            # Reference computation using dequantized weights.
            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)

            t_output_flat = t_output.flatten()
            output_flat = output.flatten()

            # Relative L1 difference; epsilon guards against an all-zero reference.
            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
            diffs.append(diff.item())
            print(f"Iteration {i}: relative L1 diff = {diff:.6f}")

            if i < 3:  # Print detailed output for first few iterations
                print(f" kernel output: {output_flat[:debug_print_count]}")
                print(f" torch output: {t_output_flat[:debug_print_count]}")

    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))

    print("\n" + "=" * 70)
    print("FP8 MoE Test Results")
    print("=" * 70)
    print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
    print(f"Max relative L1 diff: {max_diff*100:.4f}%")
    print(f"Min relative L1 diff: {min_diff*100:.4f}%")

    # Pass/Fail criteria: FP8 block quantization tolerates up to 15% mean error.
    threshold = 15.0
    if mean_diff * 100 < threshold:
        print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
    else:
        print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")

    return {"mean": mean_diff, "max": max_diff, "min": min_diff}


if __name__ == "__main__":
    run_fp8_moe_test()


# --- test_fp8_write_buffer.py (separate script in the original patch) ---


def make_cpu_infer(thread_num=80):
    """Construct a CPUInfer worker pool with the given thread count."""
    return CPUInfer(thread_num)
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Create an FP8 MOEConfig bound to the given CPUInfer pool."""
    # NOTE(review): MOEConfig is built with 4 positional args here but with a
    # trailing 0 in test_fp8_moe.py - confirm which signature kt_kernel_ext expects.
    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    cfg.max_len = 1
    cfg.quant_config.bits = 8  # FP8
    cfg.quant_config.group_size = group_size
    cfg.quant_config.zero_point = False
    cfg.pool = cpuinfer.backend_
    return cfg


def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
    """Allocate random FP8 weight bytes and float32 block scales for testing"""

    def _div_up(a, b):
        return (a + b - 1) // b

    def _rand_bytes(count):
        return torch.randint(0, 256, (count,), dtype=torch.uint8)

    # FP8 stores one byte per element; gate/up/down each hold
    # hidden_size * intermediate_size elements.
    per_mat_weight_bytes = hidden_size * intermediate_size

    # Block-wise scales: one float32 per (group_size x group_size) tile.
    # gate/up are (intermediate_size x hidden_size).
    n_blocks_n_gate_up = _div_up(intermediate_size, group_size)
    n_blocks_k = _div_up(hidden_size, group_size)
    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k

    # down is (hidden_size x intermediate_size): the block grid is transposed,
    # but the element count is the same product.
    n_blocks_n_down = n_blocks_k
    n_blocks_k_down = n_blocks_n_gate_up
    per_mat_scale_elems_down = n_blocks_n_down * n_blocks_k_down

    gate_q = _rand_bytes(expert_num * per_mat_weight_bytes)
    up_q = _rand_bytes(expert_num * per_mat_weight_bytes)
    down_q = _rand_bytes(expert_num * per_mat_weight_bytes)

    # FP8 scales are float32.
    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)

    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    )
def test_with_tp(gpu_tp_count):
    """Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
    torch.manual_seed(123)

    expert_num = 256  # Total experts resident in the MoE
    gpu_experts = expert_num  # Number of experts on GPU

    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536  # Non-multiple of gpu_tp*group_size exercises the unaligned path
    group_size = 128  # FP8 uses 128x128 block-wise scales

    cpuinfer = make_cpu_infer()
    cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)

    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)

    # Hand the raw buffers to the kernel; the tensors above must outlive the MoE.
    cfg.gate_proj = gate_q.data_ptr()
    cfg.up_proj = up_q.data_ptr()
    cfg.down_proj = down_q.data_ptr()
    cfg.gate_scale = gate_scale.data_ptr()
    cfg.up_scale = up_scale.data_ptr()
    cfg.down_scale = down_scale.data_ptr()

    moe = AMXFP8_MOE(cfg)

    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    # TP configuration.
    # Sizes per TP part (per expert) - must match the C++ side, which uses div_up.
    def div_up(a, b):
        return (a + b - 1) // b

    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count

    # W13 (gate/up): n = intermediate_size / gpu_tp, k = hidden_size.
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)

    # W2 (down): n = hidden_size, k = intermediate_size / gpu_tp.
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)

    # Totals over all gpu_experts.
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # Destination buffers: one set per GPU TP rank, holding every expert's data.
    w13_weight_bufs = []
    w13_scale_bufs = []
    w2_weight_bufs = []
    w2_scale_bufs = []
    for _ in range(gpu_tp_count):
        # w13 combines gate and up, hence the 2x factor.
        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32))
        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
        w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32))

    print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
    print(f"GPU TP count: {gpu_tp_count}")
    print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
    print(f"Original per matrix scale elements (gate/up): {per_mat_scale_elems_gate_up}")
    print(f"Original per matrix scale elements (down): {per_mat_scale_elems_down}")
    print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
    print(f"Scale elements per expert per TP (down): {scale_elems_per_expert_per_tp_down}")
    print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
    print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")

    # write_weights_to_buffer writes one expert at a time, so each call receives
    # pointers already offset to that expert's slot in every TP buffer.
    # w13 layout is interleaved by expert: [expert0_gate, expert0_up, expert1_gate, ...].
    def get_expert_ptrs(expert_id):
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []

        w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
        w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
        w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
        w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down

        for tp_idx in range(gpu_tp_count):
            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)  # float32 = 4 bytes
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)  # float32 = 4 bytes

        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    def _write_all_experts():
        # Submit one write task per expert, then wait for the pool to drain.
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Warm up.
    for _ in range(2):
        _write_all_experts()

    # Timing.
    begin_time = time.perf_counter_ns()
    _write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1000000

    # Throughput: FP8 weights are 1 byte each; scales are float32.
    total_weights = hidden_size * intermediate_size * gpu_experts * 3
    total_scale_bytes = (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4  # float32
    total_bytes = total_weights + total_scale_bytes
    print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def split_expert_tensor(tensor, chunk):
        """Split tensor by experts"""
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    # Split the flat source buffers into per-expert views.
    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)

    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)

    # Block grid of the down matrix (hidden_size x intermediate_size).
    n_blocks_n = (hidden_size + group_size - 1) // group_size
    n_blocks_k = (intermediate_size + group_size - 1) // group_size
    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count

    # Verify buffers for each TP part against a pure-python reconstruction.
    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w13_scales = []
        expected_w2_weights = []
        expected_w2_scales = []

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count

        for expert_id in range(gpu_experts):
            # For w13 (gate and up), slicing runs along intermediate_size (n direction),
            # which is contiguous, so a flat slice suffices.
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]

            # The down matrix is sliced column-wise (it is n-major over hidden_size),
            # so each row contributes a contiguous chunk of its TP columns.
            down_weight_tp_parts = []
            down_scale_tp_parts = []

            tp_slice_weight_size = intermediate_size // gpu_tp_count
            for row_idx in range(hidden_size):
                row_weight_start = row_idx * intermediate_size
                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
                down_weight_tp_parts.append(
                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                )

            # Scales only exist at block granularity, so iterate block rows.
            for bn in range(n_blocks_n):
                row_scale_start = bn * n_blocks_k
                tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp
                down_scale_tp_parts.append(
                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]
                )

            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Interleaved by expert: [gate0, up0, gate1, up1, ...]
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
            expected_w13_scales.append(up_scale_tp)
            expected_w2_weights.append(down_weight_tp)
            expected_w2_scales.append(down_scale_tp)

        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w13_scale = torch.cat(expected_w13_scales)
        expected_w2_weight = torch.cat(expected_w2_weights)
        expected_w2_scale = torch.cat(expected_w2_scales)

        print(f"=== Checking TP part {tp_idx} ===")
        print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
        print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
        print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
        print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")

        # On mismatch, report the first/max offending index before raising.
        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f" w13 weight mismatch at index {first_diff_idx}")
            print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            diff = torch.abs(w13_scale_bufs[tp_idx] - expected_w13_scale)
            max_diff_idx = diff.argmax().item()
            print(f" w13 scale mismatch, max diff at index {max_diff_idx}")
            print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
            print(f" expected: {expected_w13_scale[max_diff_idx]}")
            raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f" w2 weight mismatch at index {first_diff_idx}")
            print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            diff = torch.abs(w2_scale_bufs[tp_idx] - expected_w2_scale)
            max_diff_idx = diff.argmax().item()
            print(f" w2 scale mismatch, max diff at index {max_diff_idx}")
            print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
            print(f" expected: {expected_w2_scale[max_diff_idx]}")
            raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")

    print(
        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
    )
    return True


def main():
    """Run tests for gpu_tp_count values 1, 2, and 4."""
    tp_values = [1, 2, 4]  # TP=8 is currently omitted; append 8 to stress higher splits
    all_passed = True
    results = {}

    print("=" * 60)
    print("Testing FP8 write_weight_scale_to_buffer for TP = ", tp_values)
    print("=" * 60)

    for tp in tp_values:
        print(f"\n{'='*60}")
        print(f"Testing with gpu_tp_count = {tp}")
        print(f"{'='*60}")
        try:
            test_with_tp(tp)
            results[tp] = "PASSED"
            print(f"✓ TP={tp} PASSED")
        except Exception as e:
            results[tp] = f"FAILED: {e}"
            all_passed = False
            print(f"✗ TP={tp} FAILED: {e}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for tp, result in results.items():
        status = "✓" if "PASSED" in result else "✗"
        print(f" {status} TP={tp}: {result}")

    if all_passed:
        print("\n✓ ALL TESTS PASSED")
    else:
        print("\n✗ SOME TESTS FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
b/kt-kernel/examples/test_k2_write_buffer.py index 4b156ed..7f453f1 100644 --- a/kt-kernel/examples/test_k2_write_buffer.py +++ b/kt-kernel/examples/test_k2_write_buffer.py @@ -6,11 +6,6 @@ import torch import numpy as np -# Ensure we can import the local extension -# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) -# if REPO_ROOT not in sys.path: -# sys.path.insert(0, REPO_ROOT) - from kt_kernel import kt_kernel_ext from kt_kernel_ext import CPUInfer @@ -54,12 +49,12 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size): ) -def main(): +def test_with_tp(gpu_tp_count): + """Test write_weight_scale_to_buffer with a specific gpu_tp_count""" torch.manual_seed(123) - expert_num = 256 # Total experts + expert_num = 8 # Reduced for faster testing gpu_experts = expert_num # Number of experts on GPU - gpu_tp_count = 2 # Number of TP parts num_experts_per_tok = 8 hidden_size = 7168 @@ -94,11 +89,7 @@ def main(): cpuinfer.sync() # TP configuration - - # Since weights are col-major, we can directly divide the total size by tp_count - # Each matrix is divided into gpu_tp_count parts in memory order - - # Calculate sizes per TP part (direct division since col-major) + # Calculate sizes per TP part (per expert) weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count @@ -107,24 +98,19 @@ def main(): total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp # Create buffer lists for w13 (gate+up) and w2 (down) + # These hold all experts' data for each GPU TP w13_weight_bufs = [] w13_scale_bufs = [] w2_weight_bufs = [] w2_scale_bufs = [] for tp_idx in range(gpu_tp_count): - # w13 combines gate and up, so needs 2x the size + # w13 combines gate and up, so needs 2x the size per expert w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8)) w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, 
dtype=torch.bfloat16)) w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8)) w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16)) - # Get data pointers for all buffers - w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs] - w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs] - w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs] - w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs] - print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}") print(f"GPU TP count: {gpu_tp_count}") print(f"Original per matrix weight bytes: {per_mat_weight_bytes}") @@ -133,14 +119,56 @@ def main(): print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}") print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}") print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}") - print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}") - print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}") - for i in range(5): + # Helper function to get pointers with expert offset + # K2 write_weights_to_buffer writes one expert at a time, so we need to pass + # pointers that already point to the correct location for each expert + def get_expert_ptrs(expert_id): + w13_weight_ptrs = [] + w13_scale_ptrs = [] + w2_weight_ptrs = [] + w2_scale_ptrs = [] + + for tp_idx in range(gpu_tp_count): + # Calculate byte offsets for this expert + # w13: gate_weight + up_weight interleaved by expert + # Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...] 
+ w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp + w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp + w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp + w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp + + w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset) + w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 2) # bf16 = 2 bytes + w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset) + w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 2) # bf16 = 2 bytes + + return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs + + # Warm up + for i in range(2): + for expert_id in range(gpu_experts): + w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id) + cpuinfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + expert_id=expert_id, + w13_weight_ptrs=w13_weight_ptrs, + w13_scale_ptrs=w13_scale_ptrs, + w2_weight_ptrs=w2_weight_ptrs, + w2_scale_ptrs=w2_scale_ptrs, + ) + ) + cpuinfer.sync() + + # Timing + begin_time = time.perf_counter_ns() + for expert_id in range(gpu_experts): + w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id) cpuinfer.submit( moe.write_weight_scale_to_buffer_task( gpu_tp_count=gpu_tp_count, - gpu_experts_num=gpu_experts, + expert_id=expert_id, w13_weight_ptrs=w13_weight_ptrs, w13_scale_ptrs=w13_scale_ptrs, w2_weight_ptrs=w2_weight_ptrs, @@ -148,23 +176,10 @@ def main(): ) ) cpuinfer.sync() - - begin_time = time.perf_counter_ns() - cpuinfer.submit( - moe.write_weight_scale_to_buffer_task( - gpu_tp_count=gpu_tp_count, - gpu_experts_num=gpu_experts, - w13_weight_ptrs=w13_weight_ptrs, - w13_scale_ptrs=w13_scale_ptrs, - w2_weight_ptrs=w2_weight_ptrs, - w2_scale_ptrs=w2_scale_ptrs, - ) - ) - cpuinfer.sync() end_time = 
time.perf_counter_ns() elapsed_ms = (end_time - begin_time) / 1000000 - total_weights = hidden_size * intermediate_size * expert_num * 3 - total_bytes = total_weights // group_size + total_weights // 2 + total_weights = hidden_size * intermediate_size * gpu_experts * 3 + total_bytes = total_weights // group_size * 2 + total_weights // 2 # scale (bf16) + weight (int4) print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms") print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s") @@ -181,9 +196,6 @@ def main(): up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems) down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems) - # CPU TP count is always 2 in this test setup (one TP per NUMA node) - cpu_tp_count = 2 - # Verify buffers for each TP part for tp_idx in range(gpu_tp_count): expected_w13_weights = [] @@ -193,22 +205,22 @@ def main(): weight13_per_tp = per_mat_weight_bytes // gpu_tp_count scale13_per_tp = per_mat_scale_elems // gpu_tp_count - # Process each GPU expert - for expert_idx in range(gpu_experts): - # For w13 (gate and up), the slicing is straightforward + # Process each GPU expert + for expert_id in range(gpu_experts): + # For w13 (gate and up), the slicing is straightforward start_weight = tp_idx * weight13_per_tp end_weight = (tp_idx + 1) * weight13_per_tp start_scale = tp_idx * scale13_per_tp end_scale = (tp_idx + 1) * scale13_per_tp # Gate - gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight] - gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale] + gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight] + gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale] # Up - up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight] - up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale] + up_weight_tp = up_q_experts[expert_id][start_weight:end_weight] + up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale] # Down matrix 
needs special handling because it's sliced column-wise # We need to reconstruct it from column slices @@ -228,16 +240,17 @@ def main(): tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size down_weight_tp_parts.append( - down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size] + down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size] ) down_scale_tp_parts.append( - down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size] + down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + tp_slice_scale_size] ) # Concatenate all column slices for this TP down_weight_tp = torch.cat(down_weight_tp_parts) down_scale_tp = torch.cat(down_scale_tp_parts) + # Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...] expected_w13_weights.append(gate_weight_tp) expected_w13_weights.append(up_weight_tp) expected_w13_scales.append(gate_scale_tp) @@ -252,16 +265,85 @@ def main(): expected_w2_scale = torch.cat(expected_w2_scales) print(f"=== Checking TP part {tp_idx} ===") + print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}") + print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}") + print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}") + print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}") # Assert all checks pass - assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}" - assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}" - assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}" - assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}" + if not 
torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight): + diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight + first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1 + print(f" w13 weight mismatch at index {first_diff_idx}") + print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}") + print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}") + raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}") + + if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale): + diff = torch.abs(w13_scale_bufs[tp_idx].float() - expected_w13_scale.float()) + max_diff_idx = diff.argmax().item() + print(f" w13 scale mismatch, max diff at index {max_diff_idx}") + print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}") + print(f" expected: {expected_w13_scale[max_diff_idx]}") + raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}") + + if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight): + diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight + first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1 + print(f" w2 weight mismatch at index {first_diff_idx}") + print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}") + print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}") + raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}") + + if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale): + diff = torch.abs(w2_scale_bufs[tp_idx].float() - expected_w2_scale.float()) + max_diff_idx = diff.argmax().item() + print(f" w2 scale mismatch, max diff at index {max_diff_idx}") + print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}") + print(f" expected: {expected_w2_scale[max_diff_idx]}") + raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}") print( - f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts" 
+ f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts" ) + return True + + +def main(): + """Run tests for all gpu_tp_count values: 1, 2, 4, 8""" + tp_values = [1, 2, 4, 8] + all_passed = True + results = {} + + print("=" * 60) + print("Testing K2 write_weight_scale_to_buffer for TP = 1, 2, 4, 8") + print("=" * 60) + + for tp in tp_values: + print(f"\n{'='*60}") + print(f"Testing with gpu_tp_count = {tp}") + print(f"{'='*60}") + try: + test_with_tp(tp) + results[tp] = "PASSED" + print(f"✓ TP={tp} PASSED") + except Exception as e: + results[tp] = f"FAILED: {e}" + all_passed = False + print(f"✗ TP={tp} FAILED: {e}") + + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + for tp, result in results.items(): + status = "✓" if "PASSED" in result else "✗" + print(f" {status} TP={tp}: {result}") + + if all_passed: + print("\n✓ ALL TESTS PASSED") + else: + print("\n✗ SOME TESTS FAILED") + sys.exit(1) if __name__ == "__main__": diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index f5c4104..5326f12 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -36,6 +36,7 @@ static const bool _is_plain_ = false; #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) #include "operators/amx/awq-moe.hpp" +#include "operators/amx/fp8-moe.hpp" #include "operators/amx/k2-moe.hpp" #include "operators/amx/la/amx_kernels.hpp" #include "operators/amx/moe.hpp" @@ -255,7 +256,7 @@ void bind_moe_module(py::module_& moe_module, const char* name) { CPUInfer* cpuinfer; MoeClass* moe; int gpu_tp_count; - int gpu_experts_num; + int expert_id; std::vector w13_weight_ptrs; std::vector w13_scale_ptrs; std::vector w2_weight_ptrs; @@ -265,12 +266,12 @@ void bind_moe_module(py::module_& moe_module, const char* name) { static void inner(void* args) { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count, - args_->gpu_experts_num, 
args_->w13_weight_ptrs, args_->w13_scale_ptrs, - args_->w2_weight_ptrs, args_->w2_scale_ptrs); + args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs, + args_->w2_scale_ptrs); } static std::pair cpuinfer_interface(std::shared_ptr moe, int gpu_tp_count, - int gpu_experts_num, py::list w13_weight_ptrs, + int expert_id, py::list w13_weight_ptrs, py::list w13_scale_ptrs, py::list w2_weight_ptrs, py::list w2_scale_ptrs) { // Convert Python lists to std::vector @@ -281,15 +282,59 @@ void bind_moe_module(py::module_& moe_module, const char* name) { for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast(item)); for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast(item)); - Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num, + Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id, w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface, - py::arg("gpu_tp_count"), py::arg("gpu_experts_num"), py::arg("w13_weight_ptrs"), - py::arg("w13_scale_ptrs"), py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); + py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), + py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); + } + + // FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num) + if constexpr (std::is_same_v>) { + struct WriteWeightScaleToBufferBindings { + struct Args { + CPUInfer* cpuinfer; + MoeClass* moe; + int gpu_tp_count; + int expert_id; + std::vector w13_weight_ptrs; + std::vector w13_scale_ptrs; + std::vector w2_weight_ptrs; + std::vector w2_scale_ptrs; + }; + + static void inner(void* args) { + Args* args_ = (Args*)args; + args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count, + args_->expert_id, 
args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs, + args_->w2_scale_ptrs); + } + + static std::pair cpuinfer_interface(std::shared_ptr moe, int gpu_tp_count, + int expert_id, py::list w13_weight_ptrs, + py::list w13_scale_ptrs, py::list w2_weight_ptrs, + py::list w2_scale_ptrs) { + // Convert Python lists to std::vector + std::vector w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec; + + for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast(item)); + for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast(item)); + for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast(item)); + for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast(item)); + + Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id, + w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec}; + return std::make_pair((intptr_t)&inner, (intptr_t)args); + } + }; + + moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface, + py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), + py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); } #endif } @@ -562,6 +607,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) { bind_moe_module>(moe_module, "AMXInt4_1_MOE"); bind_moe_module>(moe_module, "AMXInt4_1KGroup_MOE"); bind_moe_module>(moe_module, "AMXInt4_KGroup_MOE"); + bind_moe_module>(moe_module, "AMXFP8_MOE"); #endif #if defined(USE_MOE_KERNEL) bind_moe_module>(moe_module, "Int8_KERNEL_MOE"); diff --git a/kt-kernel/operators/amx/awq-moe.hpp b/kt-kernel/operators/amx/awq-moe.hpp index 9936f38..23cef12 100644 --- a/kt-kernel/operators/amx/awq-moe.hpp +++ b/kt-kernel/operators/amx/awq-moe.hpp @@ -1,73 +1,49 @@ /** - * @Description : - * @Author : chenht2022 + * @Description : AWQ Int4 AMX MoE operator with KGroup quantization and zero-point support + * @Author : chenht2022, oql * @Date : 2024-07-22 02:03:22 - * @Version : 1.0.0 - * @LastEditors 
: chenht2022 - * @LastEditTime : 2024-07-25 10:35:10 + * @Version : 2.0.0 + * @LastEditors : oql + * @LastEditTime : 2025-12-10 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * + * This file implements AWQ Int4 MoE using CRTP pattern, inheriting from moe_base.hpp. + * AWQ weights are stored with group-wise scales and zero-points (KGroup Int4 with zeros). **/ #ifndef CPUINFER_OPERATOR_AMX_AWQ_MOE_H #define CPUINFER_OPERATOR_AMX_AWQ_MOE_H // #define CHECK -#include -#include -#include -// #define FORWARD_TIME_PROFILE -// #define FORWARD_TIME_REPORT - -#include - -#include -#include -#include -#include -#include -#include - -#include "../../cpu_backend/shared_mem_buffer.h" -#include "../../cpu_backend/worker_pool.h" -#include "../common.hpp" -#include "../moe-tp.hpp" -#include "la/amx.hpp" -#include "llama.cpp/ggml.h" +#include "moe_base.hpp" +/** + * @brief AWQ Int4 MoE operator using CRTP pattern + * @tparam T Kernel type for AWQ quantization + * + * This class provides AWQ-specific implementations: + * - do_gate_up_gemm: Int4 weight with KGroup scale + zeros + AMX GEMM + * - do_down_gemm: Same Int4 KGroup GEMM + * - load_weights: Load Int4 weights with group-wise scales and zero-points + */ template -class AMX_AWQ_MOE_TP { +class AMX_AWQ_MOE_TP : public AMX_MOE_BASE> { private: - int tp_part_idx; + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::tp_part_idx; + using Base::gate_bb_; + using Base::up_bb_; + using Base::down_bb_; + using Base::gate_up_ba_; + using Base::gate_bc_; + using Base::up_bc_; + using Base::down_ba_; + using Base::down_bc_; + using Base::m_local_num_; + std::filesystem::path prefix; - void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if - // quantized)] - void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if - // quantized)] - void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if - // quantized)] - - ggml_bf16_t* m_local_input_; // 
[num_experts_per_tok * max_len * hidden_size] - ggml_bf16_t* m_local_gate_output_; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_up_output_; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_down_output_; // [num_experts_per_tok * max_len * hidden_size] - - std::vector> m_local_pos_; // [max_len, num_experts_per_tok] - std::vector m_local_num_; // [expert_num] - std::vector m_expert_id_map_; // [expert_num] - std::vector m_local_input_ptr_; // [expert_num] - std::vector m_local_gate_output_ptr_; // [expert_num] - std::vector m_local_up_output_ptr_; // [expert_num] - std::vector m_local_down_output_ptr_; // [expert_num] - - std::vector> gate_up_ba_; - std::vector> gate_bb_; - std::vector> gate_bc_; - std::vector> up_bb_; - std::vector> up_bc_; - std::vector> down_ba_; - std::vector> down_bb_; - std::vector> down_bc_; #ifdef CHECK char verify_bb[100000000]; char check_bb[100000000]; @@ -274,32 +250,35 @@ class AMX_AWQ_MOE_TP { zeros_size / mat_split); zeros_file.close(); } + #ifdef CHECK inline void load_check() { memcpy(check_bb, (char*)down_bb_[compare_expers]->b, - T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size)); } void verify_load_right() { - // printf("varify down bb_0 %d\n", tp_part_idx); memcpy(verify_bb, (char*)down_bb_[compare_expers]->b, - T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); - // check if verify_bb_0 equal to check_bb_0 - if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) { + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size)); + if (memcmp(verify_bb, check_bb, + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, + config_.quant_config.group_size)) != 0) { printf("verify error\n"); - 
for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) { + for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, + config_.quant_config.group_size); + ++i) { if (verify_bb[i] != check_bb[i]) { printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i, (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]); - break; // find the first difference and exit + break; } } assert(0); } else { printf("pass verify\n"); - // pick out the 100th~150th byte of scale to see printf("numa %d, verify_bb_%d:\n", tp_part_idx, compare_expers); - size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); + size_t size = + T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size); size_t scale_size = config_.hidden_size * sizeof(float); for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) { printf("%02x ", (unsigned char)verify_bb[i]); @@ -392,7 +371,7 @@ class AMX_AWQ_MOE_TP { } // AVX-optimized function to convert INT4 zeros to float mins - // mins = zeros * scales (element-wise), where scales is float format + // mins = -(zeros * scales) (element-wise), where scales is float format inline void convert_zeros_to_mins_avx(const uint32_t* zeros_int4_packed, const float* scales, float* mins, size_t num_elements) { constexpr size_t simd_width = 8; // 每次解 8 个 int4 @@ -408,30 +387,25 @@ class AMX_AWQ_MOE_TP { } } -#ifdef FORWARD_TIME_REPORT - std::chrono::time_point last_now; -#endif - public: - using input_t = ggml_bf16_t; - using output_t = float; - GeneralMOEConfig config_; - static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; + using typename Base::input_t; + using typename Base::output_t; - AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx) { - auto& quant_config = config.quant_config; - int& group_size = quant_config.group_size; + 
AMX_AWQ_MOE_TP() = default; + + AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + auto& quant_config = config_.quant_config; if (quant_config.group_size == 0 || !quant_config.zero_point) { throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1"); } - auto& load = config.load; - auto& save = config.save; - if (load && config.path == "") { - load = false; - } - prefix = config.path; - prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); + printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + + auto& load = config_.load; + auto& save = config_.save; + + prefix = config_.path; + prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx_)); if (save) { std::cout << "Creating " << prefix << std::endl; std::filesystem::create_directories(prefix); @@ -443,77 +417,74 @@ class AMX_AWQ_MOE_TP { throw std::runtime_error("Path not found: " + prefix.string()); } } - - this->tp_part_idx = tp_part_idx; - config_ = config; - gate_proj_ = config_.gate_proj; - up_proj_ = config_.up_proj; - down_proj_ = config_.down_proj; - - MemoryRequest mem_requests; - mem_requests.append_pointer( - &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size); - mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.intermediate_size); - mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.intermediate_size); - mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.hidden_size); - - m_local_pos_.resize(config_.max_len); - for (int i = 0; i < config_.max_len; i++) { - 
m_local_pos_[i].resize(config_.num_experts_per_tok); - } - m_expert_id_map_.resize(config_.expert_num); - m_local_num_.resize(config_.expert_num); - m_local_input_ptr_.resize(config_.expert_num); - m_local_gate_output_ptr_.resize(config_.expert_num); - m_local_up_output_ptr_.resize(config_.expert_num); - m_local_down_output_ptr_.resize(config_.expert_num); - - for (size_t i = 0; i < config_.expert_num; i++) { - gate_up_ba_.push_back( - std::make_shared(config_.max_len, config_.hidden_size, group_size, nullptr)); - gate_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - down_ba_.push_back( - std::make_shared(config_.max_len, config_.intermediate_size, group_size, nullptr)); - down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); - - void* gate_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); - gate_bb_.push_back(std::make_shared(config_.intermediate_size, config_.hidden_size, - group_size, gate_bb_ptr)); - - void* up_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); - up_bb_.push_back( - std::make_shared(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr)); - - void* down_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size)); - down_bb_.push_back(std::make_shared(config_.hidden_size, config_.intermediate_size, - group_size, down_bb_ptr)); - } - for (int i = 0; i < config_.expert_num; i++) { - mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); }, - T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size)); - mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); }, - 
T::BufferC::required_size(config_.max_len, config_.intermediate_size)); - mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); }, - T::BufferC::required_size(config_.max_len, config_.intermediate_size)); - mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); }, - T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size)); - mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); }, - T::BufferC::required_size(config_.max_len, config_.hidden_size)); - } - shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests); } - ~AMX_AWQ_MOE_TP() { - // shared_mem_buffer_numa.dealloc(this); + ~AMX_AWQ_MOE_TP() = default; + + // ============================================================================ + // CRTP buffer creation - with group_size (AWQ uses zero-point) + // ============================================================================ + + size_t buffer_a_required_size_impl(size_t m, size_t k) const { + return T::BufferA::required_size(m, k, config_.quant_config.group_size); + } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + return T::BufferB::required_size(n, k, config_.quant_config.group_size); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { + return T::BufferC::required_size(m, n); } + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + return std::make_shared(n, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + // ============================================================================ + // CRTP virtual points - GEMM dispatch (uses kgroup with zeros) + // 
============================================================================ + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx]; + + // Dispatch based on qlen threshold + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } + } + + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } + } + + /** + * @brief Load Int4 weights with scales and zero-points + * + * AWQ weights include: + * - Quantized INT4 weights + * - FP16 scales (converted to FP32) + * - INT4 zeros (converted to FP32 mins = -scale * zero) + */ void load_weights() { auto& quant_config = config_.quant_config; int& group_size = quant_config.group_size; @@ -524,15 +495,12 @@ class AMX_AWQ_MOE_TP { auto pool = config_.pool->get_subpool(tp_part_idx); if (config_.gate_projs.size()) { - throw std::runtime_error("AMX load weights is not support"); + throw std::runtime_error("AMX load weights from gate_projs is not supported"); } else { - // AWQ Load from file implementation int nth = 
T::recommended_nth(config_.intermediate_size); - static uint8_t mat_type_all = 3, mat_split = 1; if (config_.load) { - throw std::runtime_error("AMX load weights from file is not support"); + throw std::runtime_error("AMX load weights from file is not supported"); } -// check process, store down matrix to check #ifdef CHECK load_check(); #endif @@ -540,7 +508,7 @@ class AMX_AWQ_MOE_TP { else if (config_.gate_scale != nullptr) #endif { - // Loading quantized weights + // Loading quantized weights with scales and zeros pool->do_work_stealing_job( nth * config_.expert_num, nullptr, [this, nth, physical_to_logical_map](int task_id) { @@ -594,7 +562,7 @@ class AMX_AWQ_MOE_TP { (ggml_fp16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), scale_elem_count); - // Convert INT4 zeros to FP32 mins + // Convert INT4 zeros to FP32 mins: mins = -(scale * zero) convert_zeros_to_mins_avx( (const uint32_t*)((uint8_t*)config_.gate_zero + ((logical_expert_id * scale_elem_count) >> 1)), gate_bb_[expert_idx]->d, gate_bb_[expert_idx]->mins, scale_elem_count); @@ -617,7 +585,7 @@ class AMX_AWQ_MOE_TP { } } else { - // Online Quantization + // Online Quantization from BF16 assert(config_.gate_proj != nullptr); pool->do_work_stealing_job( @@ -668,450 +636,21 @@ class AMX_AWQ_MOE_TP { } } - void warm_up() { - int qlen = config_.max_len; - std::vector input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector expert_ids(qlen * config_.num_experts_per_tok); - std::vector weights(qlen * config_.num_experts_per_tok); - for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) { - expert_ids[i] = i % config_.expert_num; - weights[i] = 0.01; - } - forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data()); - } - - void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - if (qlen > 1) { - 
forward_prefill(qlen, k, expert_ids, weights, input, output); - } else { - forward_decode(k, expert_ids, weights, input, output); - } - } - -#define DIRECT_OR_POOL_BY_QLEN(var, fn) \ - do { \ - if (qlen < 10) { \ - for (int i = 0; i < (var); i++) { \ - (fn)(i); \ - } \ - } else { \ - pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \ - } \ - } while (0) - -#define MATMUL_OR_VECMUL_KGROUP_BY_QLEN(...) \ - do { \ - if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { \ - amx::mat_mul_kgroup(__VA_ARGS__); \ - } else { \ - amx::vec_mul_kgroup(__VA_ARGS__); \ - } \ - } while (0) - - void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, - void* output) { - auto pool = config_.pool->get_subpool(tp_part_idx); - auto& quant_config = config_.quant_config; - int& group_size = quant_config.group_size; -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < config_.expert_num; i++) { - m_local_num_[i] = 0; - } - for (int i = 0; i < qlen; i++) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; - } - } - - for (int i = 0; i < config_.expert_num; i++) { - if (m_local_num_[i] > 0) { -#ifdef FORWARD_TIME_PROFILE - max_local_num = std::max(max_local_num, m_local_num_[i]); -#endif - m_expert_id_map_[activated_expert] = i; - activated_expert++; - } - } - - // activated_expert 已经统计完成 - - size_t offset = 0; - for (int i = 0; i < config_.expert_num; i++) { - m_local_input_ptr_[i] = m_local_input_ + offset * 
config_.hidden_size; - m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; - m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; - m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; - offset += m_local_num_[i]; - } -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - prepare_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, - (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); - } - }); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - cpy_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); - }); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int& group_size = config_.quant_config.group_size; - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / nth]; - - int ith = task_id % nth; - if (do_up) { - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], 
config_.intermediate_size, config_.hidden_size, - group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], - ith, nth); - up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, - group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx], - gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - auto up_gate_fn = [this, nth](int task_id) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < m_local_num_[expert_idx]; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - }; - DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - 
activated_expert, nullptr, - [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int& group_size = config_.quant_config.group_size; - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, - group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], - ith, nth); - down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - qlen, nullptr, - [this, nth, output, k, expert_ids, weights](int i) { - for (int e = 0; e < config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - __m512 weight = _mm512_set1_ps(weights[i * k + j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + - m_local_pos_[i][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = 
(__m512*)((float*)output + i * config_.hidden_size + e); - f32out[0] = x0; - f32out[1] = x1; - } - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: " - "%d, qlen: %d\n", - tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time, - down_time, weight_time, forward_total_time, max_local_num, qlen); -#endif - } - - void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - int qlen = 1; - auto pool = config_.pool->get_subpool(tp_part_idx); - auto& quant_config = config_.quant_config; - int& group_size = quant_config.group_size; -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < k; i++) { - if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) { - continue; - } - m_expert_id_map_[activated_expert] = expert_ids[i]; - activated_expert++; - } - - size_t offset = 0; - for (int i = 0; i < activated_expert; i++) { - auto expert_idx = m_expert_id_map_[i]; - m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * 
config_.intermediate_size; - m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size; - m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size; - offset += qlen; - } - - gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int& group_size = config_.quant_config.group_size; - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / nth]; - - int ith = task_id % nth; - if (do_up) { - amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], - up_bb_[expert_idx], up_bc_[expert_idx], ith, nth); - up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], - gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - for (int task_id = 0; task_id < nth * activated_expert; task_id++) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < qlen; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* 
up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - activated_expert, nullptr, - [this, qlen](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int& group_size = config_.quant_config.group_size; - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], - down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); - down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - for (int i = 0; i < qlen; i++) { - for (int e = 0; e < 
config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - __m512 weight = _mm512_set1_ps(weights[i * k + j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + - m_local_pos_[i][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e); - f32out[0] = x0; - f32out[1] = x1; - } - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n", - tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time, - forward_total_time); -#endif - } + // forward, forward_prefill, forward_decode, warm_up are inherited from Base }; +// ============================================================================ +// TP_MOE specialization for AMX_AWQ_MOE_TP +// Inherits from TP_MOE> to reuse merge_results implementation +// ============================================================================ + template -class TP_MOE> : public TP_MOE_Common> { +class TP_MOE> : public TP_MOE>> { public: - using TP_MOE_Common>::TP_MOE_Common; - void load_weights() { + using Base = TP_MOE>>; + using 
Base::Base; + + void load_weights() override { auto& config = this->config; auto& tps = this->tps; auto& tp_count = this->tp_count; @@ -1157,7 +696,7 @@ class TP_MOE> : public TP_MOE_Common> { ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), ((sizeof(uint8_t) * weight_elem_count) >> 1)); - // zeros TP-slicing + // down scales and zeros TP-slicing memcpy((ggml_fp16_t*)tpc.down_scale + (expert_id * scales_elem_count), (ggml_fp16_t*)config.down_scale + (expert_id * (config.intermediate_size / group_size) * config.hidden_size + @@ -1172,7 +711,7 @@ class TP_MOE> : public TP_MOE_Common> { (sizeof(uint8_t) * scales_elem_count) >> 1); for (size_t kg = 0; kg < config.hidden_size / group_size; kg++) { - // copy scale + // copy gate/up scales memcpy((ggml_fp16_t*)tpc.gate_scale + (expert_id * scales_elem_count) + kg * tpc.intermediate_size, (ggml_fp16_t*)config.gate_scale + (expert_id * ((config.hidden_size / group_size) * config.intermediate_size) + @@ -1185,7 +724,7 @@ class TP_MOE> : public TP_MOE_Common> { kg * config.intermediate_size + i * tpc.intermediate_size), (sizeof(ggml_fp16_t) * tpc.intermediate_size)); - // zeros TP-slicing + // copy gate/up zeros TP-slicing memcpy( (uint8_t*)tpc.gate_zero + (((expert_id * scales_elem_count) + kg * tpc.intermediate_size) >> 1), (uint8_t*)config.gate_zero + @@ -1202,6 +741,7 @@ class TP_MOE> : public TP_MOE_Common> { ((sizeof(uint8_t) * tpc.intermediate_size) >> 1)); } + // down weights TP-slicing (column-wise) for (size_t col = 0; col < config.hidden_size; col++) { memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size + @@ -1285,37 +825,7 @@ class TP_MOE> : public TP_MOE_Common> { } } - void merge_results(int qlen, void* output, bool incremental) { - auto pool = this->config.pool; - auto merge_fn = [this, output, incremental](int token_nth) { - auto& 
local_output_numa = this->local_output_numa; - auto& tp_configs = this->tp_configs; - auto& tp_count = this->tp_count; - auto& config = this->config; - float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size; - if (incremental) { - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0, x1; - avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1); - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0); - *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1); - } - } - for (int i = 1; i < tp_count; i++) { - float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size; - for (int e = 0; e < tp_configs[i].hidden_size; e += 16) { - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e))); - } - } - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0 = *(__m512*)(merge_to + e); - __m512 x1 = *(__m512*)(merge_to + e + 16); - avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e)); - } - }; - DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn); - } - void merge_results(int qlen, void* output) { merge_results(qlen, output, false); } + // merge_results is inherited from TP_MOE>> }; #endif diff --git a/kt-kernel/operators/amx/fp8-moe.hpp b/kt-kernel/operators/amx/fp8-moe.hpp new file mode 100644 index 0000000..7bb7b83 --- /dev/null +++ b/kt-kernel/operators/amx/fp8-moe.hpp @@ -0,0 +1,782 @@ +/** + * @Description : FP8 AMX MoE operator for DeepSeek V3.2 native inference + * @Author : oql, Codex and Claude + * @Date : 2025-12-09 + * @Version : 1.0.0 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * + * This file implements FP8 MoE using CRTP pattern, inheriting from moe_base.hpp. + * FP8 weights are stored with 128x128 block-wise scales. 
+ **/ +#ifndef CPUINFER_OPERATOR_AMX_FP8_MOE_H +#define CPUINFER_OPERATOR_AMX_FP8_MOE_H + +// #define DEBUG_FP8_MOE + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "la/amx_raw_buffers.hpp" +#include "la/amx_raw_kernels.hpp" +#include "moe_base.hpp" + +/** + * @brief FP8 MoE operator using CRTP pattern + * @tparam T Kernel type, defaults to GemmKernel224FP8 + * + * This class provides FP8-specific implementations: + * - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul + * - load_weights: Load FP8 weights with 128x128 block scales + */ +template +class AMX_FP8_MOE_TP : public AMX_MOE_BASE> { + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::down_ba_; + using Base::down_bb_; + using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; + using Base::m_local_num_; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; + + public: + using typename Base::input_t; + using typename Base::output_t; + + AMX_FP8_MOE_TP() = default; + + AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + auto& quant_config = config_.quant_config; + if (quant_config.group_size == 0 || quant_config.zero_point) { + throw std::runtime_error("KT-Kernel fp8 MoE only support block-wise FP8. 
group_size = %d, zero_point = %d", + quant_config.group_size, quant_config.zero_point); + } + printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); + } + + ~AMX_FP8_MOE_TP() = default; + // ============================================================================ + // CRTP buffer creation - with group_size + // ============================================================================ + + size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + return T::BufferB::required_size(n, k, config_.quant_config.group_size); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } + + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + return std::make_shared(n, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + // ============================================================================ + // CRTP virtual points - GEMM dispatch + // ============================================================================ + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx]; + + amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + + amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } + +#ifdef DEBUG_FP8_MOE + // Function to dump Buffer B data for debugging FP8 quantization results + inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type, + typename T::BufferB* buffer) { + auto& quant_config = config_.quant_config; + int& group_size = quant_config.group_size; + + printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx, + matrix_type.c_str()); + + // Calculate dimensions based on matrix type + int rows, cols; + size_t scale_elem_count; + if (matrix_type == "gate" || matrix_type == "up") { + rows = config_.intermediate_size; + cols = config_.hidden_size; + } else { // down + rows = config_.hidden_size; + cols = config_.intermediate_size; + } + int n_blocks_n = (rows + group_size - 1) / group_size; + int n_blocks_k = (cols + group_size - 1) / group_size; + scale_elem_count = n_blocks_n * n_blocks_k; + + // Dump scales (as BF16 converted to float) + printf(" Scales[first 16]: "); + for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) { + printf("%.6f ", buffer->d[i]); + } + printf("\n"); + + if (scale_elem_count > 16) { + printf(" Scales[last 16]: "); + int start_idx = std::max(0, (int)scale_elem_count - 16); + for (int i = start_idx; i < (int)scale_elem_count; i++) { + printf("%.6f ", buffer->d[i]); + } + printf("\n"); + } + + // Dump FP8 weights (as hex uint8) + size_t weight_size = (size_t)rows * cols; // FP8 is 1 byte per element + uint8_t* weight_ptr = 
(uint8_t*)buffer->b; + + printf(" FP8 Weights[first 32 bytes]: "); + for (int i = 0; i < std::min(32, (int)weight_size); i++) { + printf("%02x ", weight_ptr[i]); + } + printf("\n"); + + if (weight_size > 32) { + printf(" FP8 Weights[last 32 bytes]: "); + int start_idx = std::max(32, (int)weight_size - 32); + for (int i = start_idx; i < (int)weight_size; i++) { + printf("%02x ", weight_ptr[i]); + } + printf("\n"); + } + + printf(" Matrix dimensions: %dx%d (n x k), Scale blocks: %dx%d, Group size: %d, Scale elements: %zu\n", rows, cols, + n_blocks_n, n_blocks_k, group_size, scale_elem_count); + } +#endif + + /** + * @brief Load FP8 weights from contiguous memory layout + * + * Loads weights from config_.gate_proj, up_proj, down_proj with scales + * from config_.gate_scale, up_scale, down_scale. + */ + void load_weights() { + auto& quant_config = config_.quant_config; + int& group_size = quant_config.group_size; + const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; + auto pool = config_.pool->get_subpool(tp_part_idx); + + if (config_.gate_scale == nullptr) { + throw std::runtime_error("FP8 AVX MOE only support native weight."); + } + + // load weight + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map, group_size](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + // gate part + gate_bb_[expert_idx]->from_mat( + (uint8_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size), + (float*)config_.gate_scale + + (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)), + ith, nth); + // up part + up_bb_[expert_idx]->from_mat( + (uint8_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size), + 
(float*)config_.up_scale + + (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)), + ith, nth); + }, + nullptr); + + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map, group_size](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + // down part + down_bb_[expert_idx]->from_mat( + (uint8_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size), + (float*)config_.down_scale + + (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)), + ith, nth); + }, + nullptr); +#ifdef DEBUG_FP8_MOE + dump_buffer_b("Native FP8", 0, "gate", gate_bb_[0].get()); + dump_buffer_b("Native FP8", 0, "down", down_bb_[0].get()); +#endif + } + + // Fast 64-byte (512-bit) memcpy using AVX512 + static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) { + __m512i data = _mm512_loadu_si512(src); + _mm512_storeu_si512(dst, data); + } + + // Fast memcpy for arbitrary sizes using AVX512 + static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) { + uint8_t* d = (uint8_t*)dst; + const uint8_t* s = (const uint8_t*)src; + size_t chunks = bytes / 64; + for (size_t i = 0; i < chunks; i++) { + fast_memcpy_64(d, s); + d += 64; + s += 64; + } + bytes -= chunks * 64; + if (bytes > 0) { + std::memcpy(d, s, bytes); + } + } + + /** + * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format + * + * This is the inverse of the packing done in BufferBFP8Impl::from_mat. + * Optimized with AVX512 gather for efficient non-contiguous reads. 
+ * + * @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout) + * @param dst Pointer to destination in n-major layout + * @param dst_row_stride Row stride in destination buffer (number of columns in full matrix) + */ + static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) { + // row_map[packed_i] gives the base row for packed index packed_i + static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28}; + const uint64_t* src64 = reinterpret_cast(src); + + // Gather indices: src64[8*j + packed_i] for j = 0..7 + // Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group) + const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0); + + // Process each packed group (8 groups of 4 rows each = 32 rows total) + for (int packed_i = 0; packed_i < 8; packed_i++) { + const int base_row = row_map[packed_i]; + const uint64_t* base_src = src64 + packed_i; + + // Gather 8 values for j=0..7 and j=8..15 + __m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8); + __m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8); + + // Extract 4 rows from each set of 8 values + // Row 0: bits 0-15 + __m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF))); + __m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF))); + // Row 1: bits 16-31 + __m128i row1_lo = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF))); + __m128i row1_hi = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF))); + // Row 2: bits 32-47 + __m128i row2_lo = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF))); + __m128i row2_hi = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF))); + // Row 3: bits 48-63 + 
__m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48)); + __m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48)); + + // Store 32 bytes (16 x uint16) to each row + // Combine two 128-bit values into 256-bit for more efficient stores + uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride; + uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride; + uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride; + uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride; + + // Combine lo and hi into 256-bit and store + __m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo); + __m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo); + __m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo); + __m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo); + + _mm256_storeu_si256((__m256i*)row0_dst, row0_256); + _mm256_storeu_si256((__m256i*)row1_dst, row1_256); + _mm256_storeu_si256((__m256i*)row2_dst, row2_256); + _mm256_storeu_si256((__m256i*)row3_dst, row3_256); + } + } + + /** + * @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization + * + * Processing 4 blocks together means each row write is 128 bytes = 2 cache lines, + * which greatly improves write efficiency compared to 32 bytes per row. 
+ * + * @param src Array of 4 source pointers (each pointing to a 32x32 packed block) + * @param dst Destination pointer in n-major layout + * @param dst_row_stride Row stride in destination buffer + */ + static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) { + static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28}; + constexpr int K_STEP = T::K_STEP; // 32 + + // Reinterpret as uint64 arrays for efficient access + const uint64_t* src0 = reinterpret_cast(src[0]); + const uint64_t* src1 = reinterpret_cast(src[1]); + const uint64_t* src2 = reinterpret_cast(src[2]); + const uint64_t* src3 = reinterpret_cast(src[3]); + + // Process all 32 rows, writing 128 bytes (4 x 32) per row + for (int packed_i = 0; packed_i < 8; packed_i++) { + const int base_row = row_map[packed_i]; + + // Process 4 rows at a time + for (int r = 0; r < 4; r++) { + uint16_t* row_dst = reinterpret_cast(dst + (size_t)(base_row + r) * dst_row_stride); + const int shift = r * 16; + + // Unroll: process all 4 blocks x 16 columns = 64 uint16 values + // Block 0: columns 0-15 + for (int j = 0; j < 16; j++) { + row_dst[j] = static_cast(src0[8 * j + packed_i] >> shift); + } + // Block 1: columns 16-31 + for (int j = 0; j < 16; j++) { + row_dst[16 + j] = static_cast(src1[8 * j + packed_i] >> shift); + } + // Block 2: columns 32-47 + for (int j = 0; j < 16; j++) { + row_dst[32 + j] = static_cast(src2[8 * j + packed_i] >> shift); + } + // Block 3: columns 48-63 + for (int j = 0; j < 16; j++) { + row_dst[48 + j] = static_cast(src3[8 * j + packed_i] >> shift); + } + } + } + } + + /** + * @brief Reconstruct weights for a single expert to the output buffers (no temp buffer version) + * + * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage. + * Optimized version with coarse-grained task splitting for better cache utilization. 
+ * + * Key optimizations: + * - Reduced task count (~40 vs ~350) to minimize scheduling overhead + * - Larger chunks per task for better cache line utilization + * - Process multiple N_STEPs per task for better write locality + * + * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8) + * @param cpu_tp_count Number of CPU TP parts + * @param expert_id Expert index to process + * @param full_config Full configuration (before CPU TP split) + * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP) + * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP) + * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP) + * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP) + */ + void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id, + const GeneralMOEConfig& full_config, const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) const { + auto& config = config_; + const int group_size = config.quant_config.group_size; + auto pool = config.pool->get_subpool(tp_part_idx); + + constexpr int N_STEP = T::N_STEP; + constexpr int K_STEP = T::K_STEP; + constexpr int N_BLOCK = T::N_BLOCK; + constexpr int K_BLOCK = T::K_BLOCK; + + // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only ========= + const int cpu_n_w13 = config.intermediate_size; + const int cpu_k_w13 = config.hidden_size; + const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count; + const int gpu_k_w13 = full_config.hidden_size; + const int global_n_offset_w13 = tp_part_idx * cpu_n_w13; + + const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13; + const size_t gpu_w13_scale_per_mat = (size_t)div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size); + const int cpu_scale_k_blocks_w13 = div_up(cpu_k_w13, group_size); + const int gpu_scale_k_blocks_w13 = 
div_up(gpu_k_w13, group_size); + + // ========= W2 (down): Shape [hidden, intermediate], split by K ========= + const int cpu_n_w2 = config.hidden_size; + const int cpu_k_w2 = config.intermediate_size; + const int gpu_n_w2 = full_config.hidden_size; + const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count; + const int global_k_offset_w2 = tp_part_idx * cpu_k_w2; + + const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2; + const size_t gpu_w2_scale_per_mat = (size_t)div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size); + const int cpu_scale_k_blocks_w2 = div_up(cpu_k_w2, group_size); + const int gpu_scale_k_blocks_w2 = div_up(gpu_k_w2, group_size); + + // ========= Scale dimensions ========= + const int cpu_scale_n_blocks_w13 = div_up(cpu_n_w13, group_size); + const int gpu_scale_n_blocks_w13 = div_up(gpu_n_w13, group_size); + const int cpu_scale_n_blocks_w2 = div_up(cpu_n_w2, group_size); + + // ========= Optimized job layout ========= + // Use task count slightly above CPU core count for good work stealing + // For 80-core system, ~100 tasks provides good balance + constexpr int NUM_W13_TASKS = 32; // Per matrix (gate or up), total 64 for w13 + constexpr int NUM_W2_TASKS = 32; // For down matrix + constexpr int SCALE_TASKS = 3; // gate_scale, up_scale, down_scale + + const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS; + + // Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing) + const int w13_n_steps = div_up(cpu_n_w13, N_STEP); + const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS); + const int w2_n_steps = div_up(cpu_n_w2, N_STEP); + const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS); + + pool->do_work_stealing_job( + total_tasks, nullptr, + [=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) { + if (task_id < NUM_W13_TASKS * 2) { + // ========= W13 weight task: process chunk of rows x full K ========= + const bool 
is_up = task_id >= NUM_W13_TASKS; + const int chunk_idx = task_id % NUM_W13_TASKS; + const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id]; + + // Calculate row range for this task (N_STEP aligned) + const int step_start = chunk_idx * w13_steps_per_task; + const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps); + if (step_start >= w13_n_steps) return; + const int chunk_n_start = step_start * N_STEP; + const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13); + + // Process each N_STEP within this chunk + for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) { + // Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries) + const int global_n = global_n_offset_w13 + local_n_start; + const int target_gpu = global_n / gpu_n_w13; + const int n_in_gpu = global_n % gpu_n_w13; + + uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu]; + // Pointer already points to current expert's location, only add offset for up matrix + const size_t expert_weight_off = is_up ? 
gpu_w13_weight_per_mat : 0; + + // Calculate N_BLOCK info for source addressing + const int n_block_idx = local_n_start / N_BLOCK; + const int n_block_begin = n_block_idx * N_BLOCK; + const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin); + const int n_in_block = local_n_start - n_block_begin; + + // Process all K in groups of 4 K_STEPs when possible for cache efficiency + for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) { + const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin); + + // Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row) + int k_begin = 0; + for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) { + const uint8_t* src_ptrs[4]; + for (int i = 0; i < 4; i++) { + src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size + + (size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP; + } + uint8_t* dst = + weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin; + unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13); + } + + // Handle remaining K_STEPs one by one + for (; k_begin < k_block_size; k_begin += K_STEP) { + const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 + + (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size + + (size_t)k_begin * N_STEP; + uint8_t* dst = + weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin; + unpack_nk_block(src, dst, gpu_k_w13); + } + } + } + + } else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) { + // ========= W2 weight task: process chunk of rows x all K slices ========= + const int chunk_idx = task_id - NUM_W13_TASKS * 2; + const auto& bb = down_bb_[expert_id]; + + // Calculate row range for this task (N_STEP aligned) + const int step_start = chunk_idx * w2_steps_per_task; + const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps); + if (step_start >= 
w2_n_steps) return; + const int chunk_n_start = step_start * N_STEP; + const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2); + + // Process each N_STEP within this chunk + for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) { + // Calculate N_BLOCK info for source addressing + const int n_block_idx = local_n_start / N_BLOCK; + const int n_block_begin = n_block_idx * N_BLOCK; + const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin); + const int n_in_block = local_n_start - n_block_begin; + + // Process all K slices (each slice goes to a different GPU TP) + for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) { + const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2); + + const int global_k_start = global_k_offset_w2 + k_slice_start; + const int target_gpu = global_k_start / gpu_k_w2; + const int k_in_gpu_base = global_k_start % gpu_k_w2; + + uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu]; + // Pointer already points to current expert's location + const size_t expert_weight_off = 0; + + // Process K within this slice, trying 4 K_STEPs at once when aligned + for (int k_abs = k_slice_start; k_abs < k_slice_end;) { + const int k_block_idx = k_abs / K_BLOCK; + const int k_block_begin = k_block_idx * K_BLOCK; + const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin); + const int k_in_block = k_abs - k_block_begin; + const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start); + + // Check if we can process 4 K_STEPs at once + const int remaining_in_block = k_block_size - k_in_block; + const int remaining_in_slice = k_slice_end - k_abs; + + if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) { + const uint8_t* src_ptrs[4]; + for (int i = 0; i < 4; i++) { + src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size + + (size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * 
K_STEP) * N_STEP; + } + uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu; + unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2); + k_abs += 4 * K_STEP; + } else { + const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 + + (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size + + (size_t)k_in_block * N_STEP; + uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu; + unpack_nk_block(src, dst, gpu_k_w2); + k_abs += K_STEP; + } + } + } + } + + } else { + // ========= Scale copy task: simple linear copy with fast_memcpy ========= + const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS; + + if (scale_task_id < 2) { + // Gate (0) or Up (1) scale copy + const bool is_up = scale_task_id == 1; + const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id]; + + // W13 scales: copy N blocks corresponding to this CPU TP + // Note: when gpu_tp > cpu_tp, scale blocks may span multiple GPU TPs + const int bn_start_global = global_n_offset_w13 / group_size; + + for (int bn = 0; bn < cpu_scale_n_blocks_w13; bn++) { + const int global_bn = bn_start_global + bn; + const int target_gpu = global_bn / gpu_scale_n_blocks_w13; + const int gpu_bn = global_bn % gpu_scale_n_blocks_w13; + + float* scale_dst = (float*)w13_scale_ptrs[target_gpu]; + // Pointer already points to current expert's location, only add offset for up matrix + const size_t expert_scale_off = is_up ? 
gpu_w13_scale_per_mat : 0; + + fast_memcpy(scale_dst + expert_scale_off + (size_t)gpu_bn * gpu_scale_k_blocks_w13, + bb->d + (size_t)bn * cpu_scale_k_blocks_w13, cpu_scale_k_blocks_w13 * sizeof(float)); + } + } else { + // Down scale copy (scale_task_id == 2) + const auto& bb = down_bb_[expert_id]; + + // W2 scales: K dimension is split, copy to each GPU TP + for (int k_slice_idx = 0; k_slice_idx < div_up(cpu_k_w2, gpu_k_w2); k_slice_idx++) { + const int k_slice_start = k_slice_idx * gpu_k_w2; + const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2); + + const int global_k_start = global_k_offset_w2 + k_slice_start; + const int target_gpu = global_k_start / gpu_k_w2; + const int bk_gpu_base = (global_k_start % gpu_k_w2) / group_size; + + float* scale_dst = (float*)w2_scale_ptrs[target_gpu]; + // Pointer already points to current expert's location + const size_t expert_scale_off = 0; + + const int bk_start = k_slice_start / group_size; + const int bk_end = div_up(k_slice_end, group_size); + const int bk_count = bk_end - bk_start; + + for (int bn = 0; bn < cpu_scale_n_blocks_w2; bn++) { + fast_memcpy(scale_dst + expert_scale_off + (size_t)bn * gpu_scale_k_blocks_w2 + bk_gpu_base, + bb->d + (size_t)bn * cpu_scale_k_blocks_w2 + bk_start, bk_count * sizeof(float)); + } + } + } + } + }, + nullptr); + } +}; + +template +class TP_MOE> : public TP_MOE>> { + public: + using Base = TP_MOE>>; + using Base::Base; + + void load_weights() override { + auto& config = this->config; + auto& tps = this->tps; + auto& tp_count = this->tp_count; + auto pool = config.pool; + const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map; + + const int group_size = config.quant_config.group_size; + if (group_size == 0 || config.quant_config.zero_point) { + throw std::runtime_error("FP8 MoE only supports have group_size, zero_point=false"); + } + + if (config.gate_projs.empty() && config.gate_proj == nullptr) { + throw std::runtime_error("no weight 
source"); + } + const bool use_per_expert_ptrs = !config.gate_projs.empty(); + + const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size; + const size_t full_scale_elems = + (size_t)div_up(config.hidden_size, group_size) * div_up(config.intermediate_size, group_size); + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size; + const size_t tp_scale_elems = + (size_t)div_up(tpc.intermediate_size, group_size) * div_up(tpc.hidden_size, group_size); + + tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + + tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems]; + tpc.up_scale = new float[tpc.expert_num * tp_scale_elems]; + tpc.down_scale = new float[tpc.expert_num * tp_scale_elems]; + + const size_t tp_idx = (size_t)i; + const size_t gate_up_weight_src_offset = i * tp_weight_elems; + const size_t gate_up_scale_src_offset = i * tp_scale_elems; + + const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size; + const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size; + + pool->get_subpool(i)->do_work_stealing_job( + tpc.expert_num, nullptr, + [&, &tpc](int expert_id_) { + const size_t expert_id = expert_map(physical_to_logical_map, expert_id_); + + uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems; + uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems; + uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems; + + float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems; + float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems; + float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems; + + const uint8_t* gate_src; + 
const uint8_t* up_src; + const uint8_t* down_src; + const float* gate_scale_src; + const float* up_scale_src; + const float* down_scale_src; + + if (use_per_expert_ptrs) { + gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset; + up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset; + down_src = (const uint8_t*)config.down_projs[0][expert_id]; + + gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset; + up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset; + down_scale_src = (const float*)config.down_scales[0][expert_id]; + } else { + gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset; + up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset; + down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems; + + gate_scale_src = + (const float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset; + up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset; + down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems; + } + + std::memcpy(gate_dst, gate_src, tp_weight_elems); + std::memcpy(up_dst, up_src, tp_weight_elems); + std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems); + std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems); + + for (int row = 0; row < config.hidden_size; row++) { + const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset; + const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size; + std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size); + } + + const int n_blocks_n = div_up(config.hidden_size, group_size); + const int full_n_blocks_k = 
div_up(config.intermediate_size, group_size); + const int tp_n_blocks_k = div_up(tpc.intermediate_size, group_size); + for (int bn = 0; bn < n_blocks_n; bn++) { + const float* src = down_scale_src + (size_t)bn * (size_t)full_n_blocks_k + down_scale_src_block_k_offset; + float* dst = down_scale_dst + (size_t)bn * (size_t)tp_n_blocks_k; + std::memcpy(dst, src, sizeof(float) * (size_t)tp_n_blocks_k); + } + }, + nullptr); + }); + + DO_TPS_LOAD_WEIGHTS(pool); + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + delete[] (uint8_t*)tpc.gate_proj; + delete[] (uint8_t*)tpc.up_proj; + delete[] (uint8_t*)tpc.down_proj; + delete[] (float*)tpc.gate_scale; + delete[] (float*)tpc.up_scale; + delete[] (float*)tpc.down_scale; + }); + + this->weights_loaded = true; + } + + void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) { + if (this->weights_loaded == false) { + throw std::runtime_error("Not Loaded"); + } + if (this->tps.empty()) { + throw std::runtime_error("No TP parts initialized"); + } + if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count || + (int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) { + throw std::runtime_error("Pointer arrays size must match gpu_tp_count"); + } + + this->config.pool->dispense_backend()->do_numa_job([&, this](int i) { + this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs, + w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs); + }); + } +}; + +#endif // CPUINFER_OPERATOR_AMX_FP8_MOE_H diff --git a/kt-kernel/operators/amx/k2-moe.hpp b/kt-kernel/operators/amx/k2-moe.hpp index 3e3c207..67809a9 100644 --- a/kt-kernel/operators/amx/k2-moe.hpp +++ b/kt-kernel/operators/amx/k2-moe.hpp @@ -1,256 +1,120 @@ /** - * @Description : 
Skeleton for K2 AMX MoE operator. - * @Author : Codex - * @Date : 2024-07-22 - * @Version : 0.1.0 - * @LastEditors : Codex - * @LastEditTime : 2024-07-22 + * @Description : K2 AMX MoE operator for Kimi-K2 native inference + * @Author : oql, Codex and Claude + * @Date : 2025-12-09 + * @Version : 1.0.0 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * + * This file implements K2 Int4 MoE using CRTP pattern, inheriting from moe_base.hpp. + * K2 weights are stored with group-wise scales (KGroup Int4). **/ #ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H #define CPUINFER_OPERATOR_AMX_K2_MOE_H -// #define DEBUG_K2_MOE +// #define LOAD_TIME_PROFILE -#include -#include -#include -// #define FORWARD_TIME_PROFILE -#define LOAD_TIME_PROFILE +#include "moe_base.hpp" -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../../cpu_backend/shared_mem_buffer.h" -#include "../../cpu_backend/worker_pool.h" -#include "../common.hpp" -#include "../moe-tp.hpp" -#include "la/amx.hpp" -#include "llama.cpp/ggml.h" - -template -class AMX_K2_MOE_TP { - private: - int tp_part_idx = 0; - - void* gate_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] - void* up_proj_ = nullptr; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)] - void* down_proj_ = nullptr; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)] - - ggml_bf16_t* m_local_input_ = nullptr; // [num_experts_per_tok * max_len * hidden_size] - ggml_bf16_t* m_local_gate_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_up_output_ = nullptr; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_down_output_ = nullptr; // [num_experts_per_tok * max_len * hidden_size] - - std::vector> m_local_pos_; // [max_len, num_experts_per_tok] - std::vector m_local_num_; // [expert_num] - std::vector m_expert_id_map_; // [expert_num] - std::vector 
m_local_input_ptr_; // [expert_num] - std::vector m_local_gate_output_ptr_; // [expert_num] - std::vector m_local_up_output_ptr_; // [expert_num] - std::vector m_local_down_output_ptr_; // [expert_num] - - std::vector> gate_up_ba_; - std::vector> gate_bb_; - std::vector> gate_bc_; - std::vector> up_bb_; - std::vector> up_bc_; - std::vector> down_ba_; - std::vector> down_bb_; - std::vector> down_bc_; - - size_t pool_count_ = 0; // rows reserved in each scratch pool - size_t gate_up_ba_pool_bytes_ = 0; - size_t gate_bc_pool_bytes_ = 0; - size_t up_bc_pool_bytes_ = 0; - size_t down_ba_pool_bytes_ = 0; - size_t down_bc_pool_bytes_ = 0; - void* gate_up_ba_pool_ = nullptr; - void* gate_bc_pool_ = nullptr; - void* up_bc_pool_ = nullptr; - void* down_ba_pool_ = nullptr; - void* down_bc_pool_ = nullptr; -#ifdef CHECK - char verify_bb[100000000]; - char check_bb[100000000]; - uint8_t compare_expers = 3; -#endif - -#ifdef CHECK - inline void load_check() { - // TODO: implement load_check for verification. - } - - void verify_load_right() { - // TODO: implement verification helpers. 
- } -#endif - - inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type, - typename T::BufferB* buffer) { - auto& quant_config = config_.quant_config; - int& group_size = quant_config.group_size; - - printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx, - matrix_type.c_str()); - - // Calculate dimensions based on matrix type - int rows, cols, num_groups; - size_t scale_elem_count; - if (matrix_type == "gate" || matrix_type == "up") { - rows = config_.intermediate_size; - cols = config_.hidden_size; - num_groups = cols / group_size; - scale_elem_count = num_groups * rows; - } else { // down - rows = config_.hidden_size; - cols = config_.intermediate_size; - num_groups = cols / group_size; - scale_elem_count = num_groups * rows; - } - - // Dump scales (as float) - printf(" Scales[first 16]: "); - for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) { - printf("%.6f ", buffer->d[i]); - } - printf("\n"); - - if (scale_elem_count > 16) { - printf(" Scales[last 16]: "); - int start_idx = std::max(0, (int)scale_elem_count - 16); - for (int i = start_idx; i < (int)scale_elem_count; i++) { - printf("%.6f ", buffer->d[i]); - } - printf("\n"); - } - // Dump quantized weights (as hex uint8) - size_t weight_size = (rows * cols) / 2; // INT4 packed - uint8_t* weight_ptr = (uint8_t*)buffer->b; - - printf(" Weights[first 32 bytes]: "); - for (int i = 0; i < std::min(32, (int)weight_size); i++) { - printf("%02x ", weight_ptr[i]); - } - printf("\n"); - - if (weight_size > 32) { - printf(" Weights[last 32 bytes]: "); - int start_idx = std::max(32, (int)weight_size - 32); - for (int i = start_idx; i < (int)weight_size; i++) { - printf("%02x ", weight_ptr[i]); - } - printf("\n"); - } - - printf(" Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\n", rows, cols, num_groups, - group_size, scale_elem_count); - printf("\n"); - fflush(stdout); - } +/** + * 
@brief K2 Int4 MoE operator using CRTP pattern + * @tparam T Kernel type, defaults to amx::GemmKernel224Int4SmallKGroup + * + * This class provides K2-specific GEMM implementations: + * - do_gate_up_gemm: Int4 weight with KGroup scale + AMX GEMM + * - do_down_gemm: Same Int4 KGroup GEMM + * - load_weights: Load Int4 weights with group-wise scales + */ +template +class AMX_K2_MOE_TP : public AMX_MOE_BASE> { + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::down_ba_; + using Base::down_bb_; + using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; + using Base::m_local_num_; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; public: - using input_t = ggml_bf16_t; - using output_t = float; - GeneralMOEConfig config_; - static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; + using typename Base::input_t; + using typename Base::output_t; - AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_) { - auto& quant_config = config.quant_config; - int& group_size = quant_config.group_size; + AMX_K2_MOE_TP() = default; + + AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + auto& quant_config = config_.quant_config; if (quant_config.group_size == 0 || quant_config.zero_point) { throw std::runtime_error("Kimi-K2 MoE only support KGroup Int4"); } printf("Creating AMX_K2_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu())); - auto& load = config.load; - auto& save = config.save; - if (load && config.path == "") { - load = false; - } - - this->tp_part_idx = tp_part_idx_; - config_ = config; - gate_proj_ = config_.gate_proj; - up_proj_ = config_.up_proj; - down_proj_ = config_.down_proj; - - MemoryRequest mem_requests; - mem_requests.append_pointer( - &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size); - mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * 
config_.num_experts_per_tok * - config_.max_len * config_.intermediate_size); - mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.intermediate_size); - mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.hidden_size); - - m_local_pos_.resize(config_.max_len); - for (int i = 0; i < config_.max_len; i++) { - m_local_pos_[i].resize(config_.num_experts_per_tok); - } - m_expert_id_map_.resize(config_.expert_num); - m_local_num_.resize(config_.expert_num); - m_local_input_ptr_.resize(config_.expert_num); - m_local_gate_output_ptr_.resize(config_.expert_num); - m_local_up_output_ptr_.resize(config_.expert_num); - m_local_down_output_ptr_.resize(config_.expert_num); - - for (size_t i = 0; i < config_.expert_num; i++) { - gate_up_ba_.push_back( - std::make_shared(config_.max_len, config_.hidden_size, group_size, nullptr)); - gate_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - down_ba_.push_back( - std::make_shared(config_.max_len, config_.intermediate_size, group_size, nullptr)); - down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); - - void* gate_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); - gate_bb_.push_back(std::make_shared(config_.intermediate_size, config_.hidden_size, - group_size, gate_bb_ptr)); - - void* up_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size)); - up_bb_.push_back( - std::make_shared(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr)); - - void* down_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size)); 
- down_bb_.push_back(std::make_shared(config_.hidden_size, config_.intermediate_size, - group_size, down_bb_ptr)); - } - assert(T::BufferA::M_STEP == T::BufferC::M_STEP); - // TODO: need update to all *.hpp - // (config_.expert_num * T::BufferA::M_STEP) in pool_count_ is to ensure padding for each experts. - pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::BufferA::M_STEP; - - gate_up_ba_pool_bytes_ = - (T::BufferA::required_size(pool_count_, config_.hidden_size, group_size)) + pool_count_ * 64; - gate_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.intermediate_size)) + pool_count_ * 64; - up_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.intermediate_size)) + pool_count_ * 64; - down_ba_pool_bytes_ = - (T::BufferA::required_size(pool_count_, config_.intermediate_size, group_size)) + pool_count_ * 64; - down_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.hidden_size)) + pool_count_ * 64; - - mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_); - mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_); - mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_); - mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_); - mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_); - - shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests); } ~AMX_K2_MOE_TP() = default; + // ============================================================================ + // CRTP buffer creation - with group_size + // ============================================================================ + + size_t buffer_a_required_size_impl(size_t m, size_t k) const { + return T::BufferA::required_size(m, k, config_.quant_config.group_size); + } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + return T::BufferB::required_size(n, k, config_.quant_config.group_size); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { 
return T::BufferC::required_size(m, n); } + + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + return std::make_shared(n, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + // ============================================================================ + // CRTP virtual points - GEMM dispatch + // ============================================================================ + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx]; + + // Dispatch based on qlen threshold + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } + } + + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } + } + + /** + * @brief Load Int4 weights from contiguous memory 
layout + * + * Loads weights from config_.gate_proj, up_proj, down_proj with scales + * from config_.gate_scale, up_scale, down_scale. + */ void load_weights() { auto& quant_config = config_.quant_config; int& group_size = quant_config.group_size; @@ -263,6 +127,7 @@ class AMX_K2_MOE_TP { if (config_.gate_scale == nullptr) { throw std::runtime_error("Kimi AVX MOE only support load native weight."); } + // load weight int nth = T::recommended_nth(config_.intermediate_size); pool->do_work_stealing_job( @@ -314,12 +179,67 @@ class AMX_K2_MOE_TP { (ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), scale_elem_count); }, nullptr); - // dump_buffer_b("native", 0, "down", down_bb_[0].get()); +#ifdef DEBUG_K2_MOE + dump_buffer_b("native", 0, "down", down_bb_[0].get()); +#endif } - // Reconstruct weights for all experts to the output buffers - // This function handles the TP-specific portion of the reconstruction for all experts - void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int num_experts, const GeneralMOEConfig& full_config, + static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) { + uint8_t* d = (uint8_t*)dst; + const uint8_t* s = (const uint8_t*)src; + + // Main loop: 512-bit (64-byte) SIMD copies + size_t chunks = bytes / 64; + for (size_t i = 0; i < chunks; i++) { + __m512i data = _mm512_loadu_si512((__m512i*)s); + _mm512_storeu_si512((__m512i*)d, data); + d += 64; + s += 64; + } + bytes -= chunks * 64; + + // Handle remaining bytes + if (bytes > 0) { + std::memcpy(d, s, bytes); + } + } + + // Optimized SIMD float32 to bf16 conversion + static inline void fast_fp32_to_bf16(ggml_bf16_t* __restrict dst, const float* __restrict src, size_t count) { + size_t i = 0; + + // Process 32 elements at a time (2x __m512, output 1x __m512i = 32 bf16) + for (; i + 32 <= count; i += 32) { + __m512 v0 = _mm512_loadu_ps(src + i); + __m512 v1 = _mm512_loadu_ps(src + i + 16); + + // Convert to bf16 
using truncation (shift right 16 bits) + __m512i i0 = _mm512_srli_epi32(_mm512_castps_si512(v0), 16); + __m512i i1 = _mm512_srli_epi32(_mm512_castps_si512(v1), 16); + + // Pack 32-bit values to 16-bit + __m512i packed = _mm512_packus_epi32(i0, i1); + + // Reorder due to packus lane behavior: + // packus outputs interleaved: [i0[0-3], i1[0-3], i0[4-7], i1[4-7], i0[8-11], i1[8-11], i0[12-15], i1[12-15]] + // We need sequential: [i0[0-15], i1[0-15]] = [i0[0-3], i0[4-7], i0[8-11], i0[12-15], i1[0-3], i1[4-7], i1[8-11], + // i1[12-15]] Permutation: [0, 2, 4, 6, 1, 3, 5, 7] (qword indices) + __m512i permuted = _mm512_permutexvar_epi64(_mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0), packed); + + _mm512_storeu_si512((__m512i*)(dst + i), permuted); + } + + // Handle remaining elements with scalar conversion + for (; i < count; i++) { + dst[i] = ggml_fp32_to_bf16(src[i]); + } + } + + // Write a single expert's weights to the output buffers + // The caller provides pointers that already point to the target expert's location (no offset needed) + // expert_id: the index of the expert to write + // Optimized for maximum memory bandwidth using streaming stores + void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int expert_id, const GeneralMOEConfig& full_config, const std::vector& w13_weight_ptrs, const std::vector& w13_scale_ptrs, const std::vector& w2_weight_ptrs, @@ -346,95 +266,117 @@ class AMX_K2_MOE_TP { int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count); int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count); - // Get pointers for this GPU TP part + // Get pointers for this GPU TP part (already pointing to target expert's location) uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp]; ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp]; uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp]; ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp]; - // Calculate offset within the GPU TP 
buffer + // Calculate offset within the GPU TP buffer (for CPU TP slice within GPU TP) size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes; size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count; - // Process only the first num_experts experts (GPU experts) - int nth = T::recommended_nth(config_.intermediate_size); - nth = 1; + // Optimized task layout for maximum bandwidth: + // - Larger chunks to reduce task overhead + // - Separate large contiguous copies (gate_w, up_w) from strided copies (down) + // - Scale conversions are relatively small, merge with weight tasks + + // Use fewer, larger tasks for better efficiency + constexpr int NUM_WEIGHT_TASKS = 8; // Fewer tasks, larger chunks + constexpr int MIN_COLS_PER_TASK = 128; + int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK); + num_down_tasks = std::min(num_down_tasks, 32); + + // Total tasks: gate_weight + up_weight + down_weight_scale + gate_scale + up_scale + int total_tasks = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2; + + size_t weight_chunk_size = (cpu_tp_weight_bytes + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS; + // Align chunk size to 64 bytes for optimal streaming stores + weight_chunk_size = (weight_chunk_size + 63) & ~63ULL; + pool->do_work_stealing_job( - nth * num_experts, nullptr, - [&, this](int task_id) { - int expert_id = task_id / nth; - // int ith = task_id % nth; - // auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); + total_tasks, nullptr, + [&, this, num_down_tasks, expert_id, weight_chunk_size](int task_id) { + if (task_id < NUM_WEIGHT_TASKS) { + // Gate weight copy - chunked + int chunk_idx = task_id; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes); + if (start < end) { + uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b; + fast_memcpy(w13_weight_dst + offset_in_gpu_weight + start, gate_weight_src + start, end - start); + } + } else 
if (task_id < NUM_WEIGHT_TASKS * 2) { + // Up weight copy - chunked + int chunk_idx = task_id - NUM_WEIGHT_TASKS; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes); + if (start < end) { + uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b; + fast_memcpy(w13_weight_dst + offset_in_gpu_weight + gpu_tp_weight_bytes + start, up_weight_src + start, + end - start); + } + } else if (task_id < NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + // Down columns - split by column chunks + // Each task handles multiple consecutive columns for better cache locality + int chunk_idx = task_id - NUM_WEIGHT_TASKS * 2; + size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks; + size_t col_start = chunk_idx * cols_per_chunk; + size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size); - // Calculate base offsets for this expert in the GPU buffers - // For w13: each expert has gate+up, so the offset needs to account for 2x size - size_t w13_expert_base_weight = expert_id * 2 * gpu_tp_weight_bytes; - size_t w13_expert_base_scale = expert_id * 2 * gpu_tp_scale_elem_count; - size_t w2_expert_base_weight = expert_id * gpu_tp_weight_bytes; - size_t w2_expert_base_scale = expert_id * gpu_tp_scale_elem_count; + size_t weight_per_col = config_.intermediate_size >> 1; + size_t scale_per_col = config_.intermediate_size / group_size; + size_t gpu_weight_stride = (full_config.intermediate_size / gpu_tp_count) >> 1; + size_t gpu_scale_stride = (full_config.intermediate_size / gpu_tp_count) / group_size; + size_t gpu_weight_slice_offset = local_idx * weight_per_col; + size_t gpu_scale_slice_offset = local_idx * scale_per_col; - // Gate (first part of w13 for this expert) - uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b; - float* gate_scale_src = gate_bb_[expert_id]->d; - std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight, gate_weight_src, - 
cpu_tp_weight_bytes); - convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale), gate_scale_src, - cpu_tp_scale_elem_count); + for (size_t col = col_start; col < col_end; col++) { + fast_memcpy(w2_weight_dst + col * gpu_weight_stride + gpu_weight_slice_offset, + (uint8_t*)down_bb_[expert_id]->b + col * weight_per_col, weight_per_col); - // Up (second part of w13 for this expert, immediately after gate) - uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b; - float* up_scale_src = up_bb_[expert_id]->d; - std::memcpy(w13_weight_dst + w13_expert_base_weight + offset_in_gpu_weight + gpu_tp_weight_bytes, - up_weight_src, cpu_tp_weight_bytes); - convert_or_copy( - (ggml_bf16_t*)(w13_scale_dst + w13_expert_base_scale + offset_in_gpu_scale + gpu_tp_scale_elem_count), - up_scale_src, cpu_tp_scale_elem_count); - - // Down (w2) - need to handle column-wise slicing - // The down matrix is transposed compared to gate/up, so we need to extract by columns - // When multiple CPU TPs map to one GPU TP, each CPU TP has a slice of intermediate dimension - // CPU TP internal layout: each column has config_.intermediate_size elements - // GPU expects: each column has full_config.intermediate_size elements - size_t cpu_tps_per_gpu = cpu_tp_count / gpu_tp_count; - - for (size_t col = 0; col < config_.hidden_size; col++) { - // GPU buffer column width is full_config.intermediate_size / gpu_tp_count - size_t gpu_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) >> 1); - size_t cpu_col_offset = col * (config_.intermediate_size >> 1); - size_t gpu_col_slice_offset = local_idx * (config_.intermediate_size >> 1); - - std::memcpy(w2_weight_dst + w2_expert_base_weight + gpu_col_offset + gpu_col_slice_offset, - (uint8_t*)down_bb_[expert_id]->b + cpu_col_offset, config_.intermediate_size / 2); - - // Same for scales - size_t gpu_scale_col_offset = col * ((full_config.intermediate_size / gpu_tp_count) / group_size); - size_t 
cpu_scale_col_offset = col * (config_.intermediate_size / group_size); - size_t gpu_scale_slice_offset = local_idx * (config_.intermediate_size / group_size); - - convert_or_copy( - (ggml_bf16_t*)(w2_scale_dst + w2_expert_base_scale + gpu_scale_col_offset + gpu_scale_slice_offset), - down_bb_[expert_id]->d + cpu_scale_col_offset, config_.intermediate_size / group_size); + fast_fp32_to_bf16(w2_scale_dst + col * gpu_scale_stride + gpu_scale_slice_offset, + down_bb_[expert_id]->d + col * scale_per_col, scale_per_col); + } + } else if (task_id == NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + // Gate scale convert + float* gate_scale_src = gate_bb_[expert_id]->d; + fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale, gate_scale_src, cpu_tp_scale_elem_count); + } else { + // Up scale convert + float* up_scale_src = up_bb_[expert_id]->d; + fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale + gpu_tp_scale_elem_count, up_scale_src, + cpu_tp_scale_elem_count); } }, nullptr); } else { // cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs - // Each CPU TP part contains data for multiple GPU TP parts int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count; - - // This CPU TP part writes to GPU TP indices: [start_gpu_tp, start_gpu_tp + gpu_tps_per_cpu_tp) int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp; // Size of data per GPU TP within this CPU TP size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp; size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp; - // Process all experts for this GPU TP + // Optimized task layout + constexpr int NUM_WEIGHT_TASKS = 8; + constexpr int MIN_COLS_PER_TASK = 128; + int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK); + num_down_tasks = std::min(num_down_tasks, 32); + + int tasks_per_gpu_tp = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2; + int total_tasks = tasks_per_gpu_tp * gpu_tps_per_cpu_tp; + + size_t weight_chunk_size = (data_per_gpu_tp_weight + 
NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS; + weight_chunk_size = (weight_chunk_size + 63) & ~63ULL; + pool->do_work_stealing_job( - gpu_tps_per_cpu_tp * num_experts, nullptr, - [&, this](int task_id) { - int expert_id = task_id % num_experts; - int local_gpu_idx = task_id / num_experts; + total_tasks, nullptr, + [&, this, gpu_tps_per_cpu_tp, start_gpu_tp, data_per_gpu_tp_weight, data_per_gpu_tp_scale, num_down_tasks, + tasks_per_gpu_tp, expert_id, weight_chunk_size](int task_id) { + int local_gpu_idx = task_id / tasks_per_gpu_tp; + int task_type = task_id % tasks_per_gpu_tp; int gpu_tp_idx = start_gpu_tp + local_gpu_idx; // Get pointers for this GPU TP part @@ -447,649 +389,73 @@ class AMX_K2_MOE_TP { size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight; size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale; - // Calculate offsets for this expert in GPU buffers - // For w13: each expert has gate+up, so the offset needs to account for 2x size - size_t w13_gpu_expert_offset_weight = expert_id * 2 * gpu_tp_weight_bytes; - size_t w13_gpu_expert_offset_scale = expert_id * 2 * gpu_tp_scale_elem_count; - size_t w2_gpu_expert_offset_weight = expert_id * gpu_tp_weight_bytes; - size_t w2_gpu_expert_offset_scale = expert_id * gpu_tp_scale_elem_count; + if (task_type < NUM_WEIGHT_TASKS) { + // Gate weight copy - chunked + int chunk_idx = task_type; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight); + if (start < end) { + uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight; + fast_memcpy(w13_weight_dst + start, gate_weight_src + start, end - start); + } + } else if (task_type < NUM_WEIGHT_TASKS * 2) { + // Up weight copy - chunked + int chunk_idx = task_type - NUM_WEIGHT_TASKS; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight); + if (start < end) { + uint8_t* up_weight_src = 
(uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight; + fast_memcpy(w13_weight_dst + gpu_tp_weight_bytes + start, up_weight_src + start, end - start); + } + } else if (task_type < NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + // Down columns - split by column chunks + int chunk_idx = task_type - NUM_WEIGHT_TASKS * 2; + size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks; + size_t col_start = chunk_idx * cols_per_chunk; + size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size); - // Gate (first part of w13 for this expert) - uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight; - float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale; - std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight, gate_weight_src, data_per_gpu_tp_weight); - convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale), gate_scale_src, - data_per_gpu_tp_scale); + size_t weight_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) >> 1; + size_t scale_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size; - // Up (second part of w13 for this expert, immediately after gate) - uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight; - float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale; - std::memcpy(w13_weight_dst + w13_gpu_expert_offset_weight + gpu_tp_weight_bytes, up_weight_src, - data_per_gpu_tp_weight); - convert_or_copy((ggml_bf16_t*)(w13_scale_dst + w13_gpu_expert_offset_scale + gpu_tp_scale_elem_count), - up_scale_src, data_per_gpu_tp_scale); + for (size_t col = col_start; col < col_end; col++) { + size_t col_offset_weight = (col * config_.intermediate_size / 2) + + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size); + size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size); - // Down (w2) - need to handle column-wise 
slicing - // The down matrix is transposed compared to gate/up, so we need to extract by columns - for (size_t col = 0; col < config_.hidden_size; col++) { - // Calculate the offset within the column for this GPU TP part - size_t col_offset_weight = (col * config_.intermediate_size / 2) + - (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size); - size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + - (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size); + fast_memcpy(w2_weight_dst + col * weight_per_gpu_col, + (uint8_t*)down_bb_[expert_id]->b + col_offset_weight, weight_per_gpu_col); - // Copy weights column by column - std::memcpy(w2_weight_dst + w2_gpu_expert_offset_weight + - (col * (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2), - (uint8_t*)down_bb_[expert_id]->b + col_offset_weight, - (config_.intermediate_size / gpu_tps_per_cpu_tp) / 2); - - // Copy scales column by column - convert_or_copy((ggml_bf16_t*)(w2_scale_dst + w2_gpu_expert_offset_scale + - col * ((config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size)), - down_bb_[expert_id]->d + col_offset_scale, - (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size); + fast_fp32_to_bf16(w2_scale_dst + col * scale_per_gpu_col, down_bb_[expert_id]->d + col_offset_scale, + scale_per_gpu_col); + } + } else if (task_type == NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + // Gate scale convert + float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale; + fast_fp32_to_bf16(w13_scale_dst, gate_scale_src, data_per_gpu_tp_scale); + } else { + // Up scale convert + float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale; + fast_fp32_to_bf16(w13_scale_dst + gpu_tp_scale_elem_count, up_scale_src, data_per_gpu_tp_scale); } }, nullptr); } } - - void warm_up() { - int qlen = config_.max_len; - std::vector input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector 
expert_ids(qlen * config_.num_experts_per_tok); - std::vector weights(qlen * config_.num_experts_per_tok); - for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) { - expert_ids[i] = i % config_.expert_num; - weights[i] = 0.01; - } - forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data()); - } - - void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - if (qlen > 1) { - forward_prefill(qlen, k, expert_ids, weights, input, output); - } else { - forward_decode(k, expert_ids, weights, input, output); - } - } - -#ifndef DIRECT_OR_POOL_BY_QLEN -#define DIRECT_OR_POOL_BY_QLEN(var, fn) \ - do { \ - if (qlen < 10) { \ - for (int i = 0; i < (var); i++) { \ - (fn)(i); \ - } \ - } else { \ - pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \ - } \ - } while (0) -#endif - - void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, - void* output) { - auto pool = config_.pool->get_subpool(tp_part_idx); - auto& quant_config = config_.quant_config; - int& group_size = quant_config.group_size; -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < config_.expert_num; i++) { - m_local_num_[i] = 0; - } - for (int i = 0; i < qlen; i++) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; - } - } - - for (int i = 0; i < config_.expert_num; i++) { - if (m_local_num_[i] > 0) { -#ifdef FORWARD_TIME_PROFILE - max_local_num 
= std::max(max_local_num, m_local_num_[i]); -#endif - m_expert_id_map_[activated_expert] = i; - activated_expert++; - } - } - - // activated_expert 已经统计完成 - - size_t offset = 0; - void* gate_up_ba_pool_ptr = gate_up_ba_pool_; - void* gate_bc_pool_ptr = gate_bc_pool_; - void* up_bc_pool_ptr = up_bc_pool_; - void* down_ba_pool_ptr = down_ba_pool_; - void* down_bc_pool_ptr = down_bc_pool_; - constexpr size_t M_STEP = T::BufferA::M_STEP; - auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); }; - size_t used_pool_m = 0; - size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0, - used_pool_bytes_bc_down = 0; - - for (int i = 0; i < config_.expert_num; i++) { - m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size; - m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; - m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; - m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; - offset += m_local_num_[i]; - - if (m_local_num_[i] == 0) continue; - size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP; - gate_up_ba_[i]->max_m = max_m; - gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr); - size_t ba_size = align64(T::BufferA::required_size(max_m, config_.hidden_size, group_size)); - gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size); - gate_bc_[i]->max_m = max_m; - gate_bc_[i]->set_data(gate_bc_pool_ptr); - size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size)); - gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size); - up_bc_[i]->max_m = max_m; - up_bc_[i]->set_data(up_bc_pool_ptr); - size_t bc_up_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size)); - up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size); - down_ba_[i]->max_m = max_m; - down_ba_[i]->set_data(down_ba_pool_ptr); - 
size_t ba_down_size = align64(T::BufferA::required_size(max_m, config_.intermediate_size, group_size)); - down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size); - down_bc_[i]->max_m = max_m; - down_bc_[i]->set_data(down_bc_pool_ptr); - size_t bc_down_size = align64(T::BufferC::required_size(max_m, config_.hidden_size)); - down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size); - used_pool_m += max_m; - used_pool_bytes_a += ba_size; - used_pool_bytes_bc_gate += bc_gate_size; - used_pool_bytes_bc_up += bc_up_size; - used_pool_bytes_ba_down += ba_down_size; - used_pool_bytes_bc_down += bc_down_size; - } - assert(used_pool_m <= pool_count_); - assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_); - assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_); - assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_); - assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_); - assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - prepare_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, - (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); - } - }); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - cpy_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); - }); - -#ifdef 
FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int& group_size = config_.quant_config.group_size; - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / nth]; - - int ith = task_id % nth; - if (do_up) { - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, - group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], - ith, nth); - up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, - group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx], - gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - auto up_gate_fn = [this, nth](int task_id) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < m_local_num_[expert_idx]; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - 
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - }; - DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - activated_expert, nullptr, - [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int& group_size = config_.quant_config.group_size; - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, - group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], - ith, nth); - down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - qlen, nullptr, - [this, nth, output, k, expert_ids, weights](int i) { - for (int 
e = 0; e < config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - __m512 weight = _mm512_set1_ps(weights[i * k + j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + - m_local_pos_[i][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e); - f32out[0] = x0; - f32out[1] = x1; - } - }, - nullptr); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: " - "%d, qlen: %d\n", - tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time, - down_time, weight_time, forward_total_time, max_local_num, qlen); -#endif - // for (int i = 0; i < qlen; i ++) - // forward_decode(k, expert_ids + i * k, weights + i * k, (ggml_bf16_t*)input + i * config_.hidden_size, - // (float*)output + i * config_.hidden_size); - } - - void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - int qlen = 1; - auto pool = config_.pool->get_subpool(tp_part_idx); - auto& quant_config = 
config_.quant_config; - int& group_size = quant_config.group_size; -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < k; i++) { - if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) { - continue; - } - m_expert_id_map_[activated_expert] = expert_ids[i]; - activated_expert++; - } - - size_t offset = 0; - for (int i = 0; i < activated_expert; i++) { - auto expert_idx = m_expert_id_map_[i]; - m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size; - m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size; - m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size; - offset += qlen; - } - - void* gate_bc_pool_ptr = gate_bc_pool_; - void* up_bc_pool_ptr = up_bc_pool_; - void* down_ba_pool_ptr = down_ba_pool_; - void* down_bc_pool_ptr = down_bc_pool_; - constexpr size_t M_STEP = T::BufferA::M_STEP; - auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); }; - size_t used_pool_m = 0; - size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0, - used_pool_bytes_bc_down = 0; - for (int i = 0; i < activated_expert; i++) { - auto expert_idx = m_expert_id_map_[i]; - size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; - - gate_bc_[expert_idx]->max_m = max_m; - gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr); - size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size)); - gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size); - - up_bc_[expert_idx]->max_m = max_m; - 
up_bc_[expert_idx]->set_data(up_bc_pool_ptr); - size_t bc_up_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size)); - up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size); - - down_ba_[expert_idx]->max_m = max_m; - down_ba_[expert_idx]->set_data(down_ba_pool_ptr); - size_t ba_down_size = align64(T::BufferA::required_size(max_m, config_.intermediate_size, group_size)); - down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size); - - down_bc_[expert_idx]->max_m = max_m; - down_bc_[expert_idx]->set_data(down_bc_pool_ptr); - size_t bc_down_size = align64(T::BufferC::required_size(max_m, config_.hidden_size)); - down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size); - - used_pool_m += max_m; - used_pool_bytes_bc_gate += bc_gate_size; - used_pool_bytes_bc_up += bc_up_size; - used_pool_bytes_ba_down += ba_down_size; - used_pool_bytes_bc_down += bc_down_size; - } - assert(used_pool_m <= pool_count_); - assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_); - assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_); - assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_); - assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_); - - gate_up_ba_[0]->max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; - gate_up_ba_[0]->set_data(gate_up_ba_pool_); - gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - // calc gate & up - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int& group_size = config_.quant_config.group_size; - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / nth]; - - int ith = task_id % nth; - if (do_up) { - 
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], - up_bb_[expert_idx], up_bc_[expert_idx], ith, nth); - up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0], - gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); - -#ifdef DEBUG_K2_MOE - if (activated_expert > 0) { - int print_elems = std::min(config_.intermediate_size, 16); - for (int dbg = 0; dbg < activated_expert; ++dbg) { - int sample_expert = m_expert_id_map_[dbg]; - ggml_bf16_t* gate_ptr = m_local_gate_output_ptr_[sample_expert]; - if (gate_ptr == nullptr) { - continue; - } - - printf("[K2][TP %d] gate_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems); - for (int idx = 0; idx < print_elems; idx++) { - float val = ggml_bf16_to_fp32(gate_ptr[idx]); - printf("%.6f ", val); - } - printf("\n"); - - int tail_start = config_.intermediate_size > print_elems ? 
config_.intermediate_size - print_elems : 0; - printf("[K2][TP %d] gate_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems); - for (int idx = 0; idx < print_elems; idx++) { - float val = ggml_bf16_to_fp32(gate_ptr[tail_start + idx]); - printf("%.6f ", val); - } - printf("\n"); - } - } -#endif - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - // act - for (int task_id = 0; task_id < nth * activated_expert; task_id++) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < qlen; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - // quant, get down a - pool->do_work_stealing_job( - activated_expert, nullptr, - [this, qlen](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = 
std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - // * down - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int& group_size = config_.quant_config.group_size; - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], - down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); - down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); - -#ifdef DEBUG_K2_MOE - if (activated_expert > 0) { - int print_elems = std::min(config_.hidden_size, 16); - for (int dbg = 0; dbg < activated_expert; ++dbg) { - int sample_expert = m_expert_id_map_[dbg]; - ggml_bf16_t* down_ptr = m_local_down_output_ptr_[sample_expert]; - if (down_ptr == nullptr) { - continue; - } - - printf("[K2][TP %d] down_out (expert %d, first %d elems): ", tp_part_idx, sample_expert, print_elems); - for (int idx = 0; idx < print_elems; idx++) { - float val = ggml_bf16_to_fp32(down_ptr[idx]); - printf("%.6f ", val); - } - printf("\n"); - - int tail_start = config_.hidden_size > print_elems ? 
config_.hidden_size - print_elems : 0; - printf("[K2][TP %d] down_out (expert %d, last %d elems): ", tp_part_idx, sample_expert, print_elems); - for (int idx = 0; idx < print_elems; idx++) { - float val = ggml_bf16_to_fp32(down_ptr[tail_start + idx]); - printf("%.6f ", val); - } - printf("\n"); - } - } -#endif - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - // get output - for (int e = 0; e < config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) { - continue; - } - __m512 weight = _mm512_set1_ps(weights[j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32( - (__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = (__m512*)((float*)output + e); - f32out[0] = x0; - f32out[1] = x1; - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n", - tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time, - forward_total_time); -#endif - } }; -template -class TP_MOE> : public TP_MOE_Common> { - public: - 
using TP_MOE_Common>::TP_MOE_Common; +// ============================================================================ +// TP_MOE specialization for AMX_K2_MOE_TP +// Inherits from TP_MOE> to reuse merge_results implementation +// ============================================================================ - void load_weights() { +template +class TP_MOE> : public TP_MOE>> { + public: + using Base = TP_MOE>>; + using Base::Base; + + void load_weights() override { auto& config = this->config; auto& tps = this->tps; auto& tp_count = this->tp_count; @@ -1102,10 +468,9 @@ class TP_MOE> : public TP_MOE_Common> { long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0; #endif - // Check if using per-expert pointers (gate_projs) or contiguous memory (gate_proj + gate_scale) bool use_per_expert_ptrs = !config.gate_projs.empty(); - if (!use_per_expert_ptrs && config.gate_scale == nullptr) { + if (config.gate_projs.empty() && config.gate_scale == nullptr) { throw std::runtime_error("K2 MoE only supports Packed Int4 with KGroup Scale"); } @@ -1118,15 +483,11 @@ class TP_MOE> : public TP_MOE_Common> { int& group_size = config.quant_config.group_size; if (use_per_expert_ptrs) { - // Load from per-expert pointers - no need to allocate intermediate buffers - // gate_projs[numa_id][expert_id] -> pointer to expert weight - // For RAWINT4, numa dimension is 1 (index 0) for (auto i = 0; i < tp_count; i++) { auto& tpc = tps[i]->config_; size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size; size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size; - // Allocate per-TP buffers tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; @@ -1139,8 +500,6 @@ class TP_MOE> : public TP_MOE_Common> { [&, i](int expert_id_) { size_t expert_id = expert_map(physical_to_logical_map, expert_id_); - 
// Source pointers from per-expert pointer arrays - // gate_projs[0][expert_id] since numa dimension is 1 uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id]; uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id]; uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id]; @@ -1148,7 +507,6 @@ class TP_MOE> : public TP_MOE_Common> { ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id]; ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id]; - // TP-slicing for gate and up (row-major slicing) memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1), src_gate + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1)); @@ -1161,7 +519,6 @@ class TP_MOE> : public TP_MOE_Common> { memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count), src_up_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count); - // TP-slicing for down (by column) for (size_t col = 0; col < config.hidden_size; col++) { memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1), @@ -1177,7 +534,6 @@ class TP_MOE> : public TP_MOE_Common> { printf("TP %d load weight done.\n", i); } } else { - // Original path: load from contiguous memory with gate_proj/gate_scale for (auto i = 0; i < tp_count; i++) { auto& tpc = tps[i]->config_; size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size; @@ -1191,13 +547,12 @@ class TP_MOE> : public TP_MOE_Common> { tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)]; tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)]; - if (tps[i]->config_.load == false) { + if (tpc.load == false) { pool->get_subpool(i)->do_work_stealing_job( tpc.expert_num, nullptr, - [&](int expert_id_) { // weight and scale are all in col majored. 
+ [&](int expert_id_) { size_t expert_id = expert_map(physical_to_logical_map, expert_id_); - // weight and scale TP-slicing for gate and up memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1), (uint8_t*)config.gate_proj + ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), @@ -1220,7 +575,6 @@ class TP_MOE> : public TP_MOE_Common> { i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count); - // weight and scale TP-slicing for down (by column) for (size_t col = 0; col < config.hidden_size; col++) { memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size + @@ -1288,8 +642,7 @@ class TP_MOE> : public TP_MOE_Common> { this->weights_loaded = true; } - void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num, - const std::vector& w13_weight_ptrs, + void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector& w13_weight_ptrs, const std::vector& w13_scale_ptrs, const std::vector& w2_weight_ptrs, const std::vector& w2_scale_ptrs) { @@ -1300,59 +653,18 @@ class TP_MOE> : public TP_MOE_Common> { throw std::runtime_error("No TP parts initialized"); } - // Validate input vector sizes if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count || w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) { throw std::runtime_error("Pointer arrays size must match gpu_tp_count"); } - auto& config = this->config; - auto pool = config.pool; - // Each TP part writes to its corresponding buffer - pool->dispense_backend()->do_numa_job([this, pool, gpu_tp_count, gpu_experts_num, w13_weight_ptrs, w13_scale_ptrs, - w2_weight_ptrs, w2_scale_ptrs](int numa_id) { - // Note: w13 combines gate and up projections - // Split w13 pointers for gate and up - this->tps[numa_id]->write_weights_to_buffer(gpu_tp_count, 
this->tp_count, gpu_experts_num, this->config, - w13_weight_ptrs, w13_scale_ptrs, // gate + up use w13 - w2_weight_ptrs, w2_scale_ptrs); // down uses w2 + this->config.pool->dispense_backend()->do_numa_job([&, this](int i) { + this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs, + w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs); }); } - void merge_results(int qlen, void* output, bool incremental) { - auto pool = this->config.pool; - auto merge_fn = [this, output, incremental](int token_nth) { - auto& local_output_numa = this->local_output_numa; - auto& tp_configs = this->tp_configs; - auto& tp_count = this->tp_count; - auto& config = this->config; - float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size; - if (incremental) { - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0, x1; - avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1); - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0); - *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1); - } - } - for (int i = 1; i < tp_count; i++) { - float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size; - for (int e = 0; e < tp_configs[i].hidden_size; e += 16) { - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e))); - } - } - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0 = *(__m512*)(merge_to + e); - __m512 x1 = *(__m512*)(merge_to + e + 16); - avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e)); - } - }; - for (int i = 0; i < qlen; i++) { - merge_fn(i); - } - } - - void merge_results(int qlen, void* output) { merge_results(qlen, output, false); } + // merge_results is inherited from TP_MOE>> }; #endif // CPUINFER_OPERATOR_AMX_K2_MOE_H diff --git 
a/kt-kernel/operators/amx/la/amx.hpp b/kt-kernel/operators/amx/la/amx.hpp index 7130392..c8ce391 100644 --- a/kt-kernel/operators/amx/la/amx.hpp +++ b/kt-kernel/operators/amx/la/amx.hpp @@ -46,6 +46,9 @@ static inline __m512 exp_avx512(__m512 x) { static inline __m512 act_fn(__m512 gate_val, __m512 up_val) { __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val); + // Clamp neg_gate_val to avoid exp overflow (exp(88) overflows for float32) + const __m512 max_exp_input = _mm512_set1_ps(88.0f); + neg_gate_val = _mm512_min_ps(neg_gate_val, max_exp_input); __m512 exp_neg_gate = exp_avx512(neg_gate_val); __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate); __m512 act_val = _mm512_div_ps(gate_val, denom); diff --git a/kt-kernel/operators/amx/la/amx_kernels.hpp b/kt-kernel/operators/amx/la/amx_kernels.hpp index b7333a2..65f0643 100644 --- a/kt-kernel/operators/amx/la/amx_kernels.hpp +++ b/kt-kernel/operators/amx/la/amx_kernels.hpp @@ -762,6 +762,16 @@ struct GemmKernel224BF { struct BufferC { float* c; int max_m, n; + // 物理布局(按 float 元素数): + // 逻辑矩阵 C 为 (max_m, n) 行主序,max_m 为 M_STEP 的倍数, + // n 按 N_BLOCK 分块。 + // 存储顺序: + // n_block(N_BLOCK 列) → m_block(M_STEP 行) → n_step(N_STEP 列) → (M_STEP×N_STEP) 行主序 tile。 + // 因此可视为 5D: + // c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP], + // n_blocks = ceil(n / N_BLOCK),m_blocks = max_m / M_STEP, + // n_steps = N_BLOCK / N_STEP(尾块可能更小)。 + // get_submat(m_begin, n_begin) 返回连续的 (M_STEP×N_STEP) tile 起始地址。 static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; } diff --git a/kt-kernel/operators/amx/la/amx_raw_buffers.hpp b/kt-kernel/operators/amx/la/amx_raw_buffers.hpp new file mode 100644 index 0000000..b485966 --- /dev/null +++ b/kt-kernel/operators/amx/la/amx_raw_buffers.hpp @@ -0,0 +1,488 @@ +#ifndef AMX_RAW_BUFFERS_HPP +#define AMX_RAW_BUFFERS_HPP + +/** + * @file amx_raw_buffers.hpp + * @brief Raw data format buffer management (FP8, BF16, etc.) 
+ * + * 本文件实现原精度格式的缓冲区管理,用于 DeepSeek V3.2 等原精度推理。 + * + * 缓冲区类型: + * - BufferAFP8Impl: 输入激活缓冲区,支持动态 FP8 量化 + * - BufferBFP8Impl: 权重缓冲区,FP8 格式 + 128x128 块缩放 + * - BufferBFP8BlockImpl: 优化的块量化权重缓冲区 + * + * 内存布局: + * - FP8 数据:1 字节/元素 + * - Scale:4 字节/块(BufferB 每 128x128 块一个,BufferA 每 128 行一个) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "amx_config.hpp" +#include "amx_utils.hpp" +#include "llama.cpp/ggml-impl.h" +#include "pack.hpp" +#include "utils.hpp" + +namespace amx { + +// ============================================================================ +// BufferAFP8Impl: FP8 激活缓冲区(支持动态量化) +// ============================================================================ +/* 物理布局(按 bf16 元素数): + * 逻辑矩阵 A 为 (m, k) 行主序,m pad 到 max_m(=m_block_size,M_STEP 的倍数)。 + * 存储顺序: + * k_block(K_BLOCK 列) → m_block(M_STEP 行) → k_step(K_STEP 列) → (M_STEP×K_STEP) 行主序 tile。 + * 因此可视为 5D: + * a[k_blocks][m_blocks][k_steps][M_STEP][K_STEP], + * k_blocks = ceil(k / K_BLOCK),m_blocks = max_m / M_STEP, + * k_steps = K_BLOCK / K_STEP(最后一个 k_block 可能更小)。 + * get_submat(m_begin, k_begin) 返回连续的 (M_STEP×K_STEP) tile。 + */ +template +struct BufferABF16Impl { + ggml_bf16_t* a; + int max_m, k; + static constexpr int M_STEP = K::M_STEP; + static constexpr int K_STEP = K::K_STEP; + static constexpr int K_BLOCK = K::K_BLOCK; + + static size_t required_size(int max_m, int k) { return sizeof(ggml_bf16_t) * max_m * k; } + + BufferABF16Impl(int max_m, int k, void* ptr) : max_m(max_m), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(k % K_STEP == 0); + a = reinterpret_cast(ptr); + } + + void set_data(void* new_ptr) { a = reinterpret_cast(new_ptr); } + + void from_mat(int m, ggml_bf16_t* src, int ith, int nth) { + assert(m <= max_m); + assert(ith == 0 && nth == 1); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int k_block_begin = 0; k_block_begin < k; 
k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512i* s = (__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin); + __m512i* d = + (__m512i*)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP); + avx512_copy_32xbf16(s, d); + } + } + } + } + } + + ggml_bf16_t* get_submat(int m, int k, int m_begin, int k_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP; + } +}; + +// ============================================================================ +// BufferB +// ============================================================================ + +/** + * @brief BF16 BufferB + * 物理布局(按 bf16 元素数): + * 逻辑矩阵 B 为 (n, k) 行主序(用于 NT GEMM),n 按 N_BLOCK 分块。 + * 存储顺序: + * n_block(N_BLOCK 行) → k_block(K_BLOCK 列) → n_step(N_STEP 行) → k_step(K_STEP 列) + * → (N_STEP×K_STEP) tile;每个 tile 内部再对两个 16×16 子块做 transpose, + * 以匹配 AMX BTile 的 VNNI 布局(TILE_K/VNNI_BLK × TILE_N*VNNI_BLK)。 + * 因此可视为 6D: + * b[n_blocks][k_blocks][n_steps][k_steps][N_STEP][K_STEP], + * n_blocks = ceil(n / N_BLOCK),k_blocks = ceil(k / K_BLOCK), + * n_steps = N_BLOCK / N_STEP,k_steps = K_BLOCK / K_STEP(尾块可能更小)。 + * get_submat(n_begin, k_begin) 返回连续的 (N_STEP×K_STEP) tile 起始地址。 + * @tparam K Kernel 类型 + */ + +template +struct BufferBBF16Impl { + ggml_bf16_t* b; + int n, k; + static constexpr bool SCALE = false; + static constexpr int N_STEP = K::N_STEP; + static constexpr int K_STEP = K::K_STEP; + static constexpr int N_BLOCK = K::N_BLOCK; + static constexpr int K_BLOCK = K::K_BLOCK; + static constexpr int TILE_N = K::TILE_N; + static size_t required_size(int n, int k) { return 
sizeof(ggml_bf16_t) * n * k; } + + BufferBBF16Impl(int n, int k, void* ptr) : n(n), k(k) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(n % N_STEP == 0); + assert(k % K_STEP == 0); + b = reinterpret_cast(ptr); + } + void set_data(void* new_ptr) { b = reinterpret_cast(new_ptr); } + + void from_mat(ggml_bf16_t* src, int ith, int nth) { + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + for (int i = 0; i < N_STEP; i++) { + __m512i* s = (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin); + __m512i* d = (__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + + k_begin * N_STEP + i * K_STEP); + avx512_copy_32xbf16(s, d); + } + transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP)); + transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + + n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP)); + } + } + } + } + ggml_bf16_t* get_submat(int n, int k, int n_begin, int k_begin) { + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + n_begin -= n_block_begin; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP; + } +}; + +/** + * @brief FP8 权重缓冲区 + * + * 存储 FP8 格式的权重矩阵,每个 128x128 块有一个缩放因子。 + * 这与 DeepSeek V3.2 的原精度格式匹配。 + * + * @tparam K Kernel 类型 + */ +template +struct BufferBFP8Impl { + 
uint8_t* b; // FP8 weight + float* d; // scale_inv [n / k_group_size, k / k_group_size] + int n, k, k_group_size; // k_group_size = 128 in DeepSeek + + static constexpr int N_STEP = K::N_STEP; + static constexpr int K_STEP = K::K_STEP; + static constexpr int N_BLOCK = K::N_BLOCK; + static constexpr int K_BLOCK = K::K_BLOCK; + static constexpr bool SCALE = true; + + /** + * @brief 计算所需内存大小 + */ + static size_t required_size(int n, int k, int k_group_size) { + int n_blocks_n = (n + k_group_size - 1) / k_group_size; + int n_blocks_k = (k + k_group_size - 1) / k_group_size; + return sizeof(uint8_t) * n * k + sizeof(float) * n_blocks_n * n_blocks_k; + } + + /** + * @brief 构造函数 + */ + BufferBFP8Impl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) { set_data(ptr); } + + void set_data(void* ptr) { + assert(reinterpret_cast(ptr) % 64 == 0); + b = reinterpret_cast(ptr); + d = reinterpret_cast(b + (size_t)n * k); + } + + static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7}; // fp8 matrix offset for reordering + /** + * @brief 从原始 FP8 权重加载(已经是量化格式) + * + * @param b_src FP8 权重源数据 (n-major, n×k) + * @param d_src FP32 scale_inv 源数据 (n-major, ceil(n/128)×ceil(k/128)) + */ + void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) { + assert(b != nullptr && d != nullptr); + assert(N_STEP == 32 && K_STEP == 32); // from mat block copy assumes this + + // Copy scales (per 128x128 block). Each thread copies its own n-block range. + const int n_blocks_k = (k + k_group_size - 1) / k_group_size; + if (d_src != nullptr) { + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + int bn_start = n_start / k_group_size; + int bn_end = (n_end + k_group_size - 1) / k_group_size; + memcpy(d + bn_start * n_blocks_k, d_src + bn_start * n_blocks_k, + sizeof(float) * (bn_end - bn_start) * n_blocks_k); + } + + // Reorder FP8 weights into KT block-major layout (same panel->tile order as BF16 BufferB). 
+ auto [n_start, n_end] = K::split_range_n(n, ith, nth); + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + int n_step_size = std::min(N_STEP, n_block_size - n_begin); + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + int k_step_size = std::min(K_STEP, k_block_size - k_begin); + // [k_step_size, n_step_size] block copy + const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin; + uint64_t* block_b_dst = + reinterpret_cast(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + + (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP); + for (int i = 0; i < 8; i++) { + const uint16_t* s = reinterpret_cast(block_b_src + (size_t)i * k * 4); + for (int j = 0; j < 16; j++) { + uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) | + (((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48); + block_b_dst[8 * j + mat_offset[i]] = val; + } + } + } + } + } + } + + /** + * @brief get scale_inv + */ + float* get_scale(int n, int n_begin, int k, int k_begin) { + int n_blocks_k = (k + k_group_size - 1) / k_group_size; + int bn = n_begin / k_group_size; + int bk = k_begin / k_group_size; + return d + bn * n_blocks_k + bk; + } + + /** + * @brief 获取子矩阵指针 + */ + uint8_t* get_submat(int n, int k, int n_begin, int k_begin) { + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + n_begin -= n_block_begin; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + int k_block_begin = k_begin / K_BLOCK * K_BLOCK; + k_begin -= k_block_begin; + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size + + (size_t)k_begin * 
N_STEP; + } + + /** + * @brief Inverse mapping for mat_offset used in to_mat + * mat_offset = {0, 2, 4, 6, 1, 3, 5, 7} + * inv_mat_offset[mat_offset[i]] = i + */ + static constexpr int inv_mat_offset[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + + /** + * @brief Unpack FP8 weights from KT block-major layout back to n-major layout + * + * This is the inverse operation of from_mat. + * + * @param b_dst FP8 输出缓冲区 (n-major, n×k) + * @param d_dst FP32 scale_inv 输出缓冲区 (n-major, ceil(n/128)×ceil(k/128)) + * @param ith Thread index + * @param nth Total number of threads + */ + void to_mat(uint8_t* b_dst, float* d_dst, int ith, int nth) const { + assert(b != nullptr && d != nullptr); + assert(N_STEP == 32 && K_STEP == 32); + + // Calculate N_BLOCK range for this thread + // Unlike split_range_n which gives one N_BLOCK per thread, we need to handle + // the case where nth < n/N_BLOCK (fewer threads than blocks) + int total_n_blocks = (n + N_BLOCK - 1) / N_BLOCK; + int blocks_per_thread = (total_n_blocks + nth - 1) / nth; + int start_n_block_idx = ith * blocks_per_thread; + int end_n_block_idx = std::min((ith + 1) * blocks_per_thread, total_n_blocks); + + // Copy scales (per 128x128 block). Each thread copies its own n-block range. 
+ const int n_blocks_k = (k + k_group_size - 1) / k_group_size; + if (d_dst != nullptr) { + int bn_start = start_n_block_idx; + int bn_end = end_n_block_idx; + memcpy(d_dst + bn_start * n_blocks_k, d + bn_start * n_blocks_k, + sizeof(float) * (bn_end - bn_start) * n_blocks_k); + } + + // Reorder FP8 weights back to n-major layout (inverse of from_mat) + // Process each N_BLOCK assigned to this thread + for (int n_block_idx = start_n_block_idx; n_block_idx < end_n_block_idx; n_block_idx++) { + int n_block_begin = n_block_idx * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) { + int k_block_size = std::min(K_BLOCK, k - k_block_begin); + for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) { + // Source: packed layout (KT block-major) + const uint64_t* block_b_src = + reinterpret_cast(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + + (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP); + + // Destination: n-major layout + uint8_t* block_b_dst = b_dst + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin; + + // Inverse of from_mat transformation + for (int packed_i = 0; packed_i < 8; packed_i++) { + int i = inv_mat_offset[packed_i]; + uint16_t* d_row = reinterpret_cast(block_b_dst + (size_t)i * k * 4); + for (int j = 0; j < 16; j++) { + uint64_t val = block_b_src[8 * j + packed_i]; + d_row[j] = (uint16_t)(val & 0xFFFF); + d_row[j + (k / 2) * 1] = (uint16_t)((val >> 16) & 0xFFFF); + d_row[j + (k / 2) * 2] = (uint16_t)((val >> 32) & 0xFFFF); + d_row[j + (k / 2) * 3] = (uint16_t)((val >> 48) & 0xFFFF); + } + } + } + } + } + } + } +}; + +// ============================================================================ +// BufferCFP8Impl: FP32 输出缓冲区 +// ============================================================================ + +/** + * @brief FP32 输出缓冲区 + 
* + * 存储 FP32 格式的累加器,支持转换为 BF16 输出 + * + * @tparam K Kernel 类型 + */ +template +struct BufferCFP32Impl { + float* c; + int max_m, n; + static constexpr int M_STEP = K::M_STEP; + static constexpr int N_STEP = K::N_STEP; + static constexpr int N_BLOCK = K::N_BLOCK; + // 物理布局(按 float 元素数): + // 逻辑矩阵 C 为 (max_m, n) 行主序,max_m 为 M_STEP 的倍数, + // n 按 N_BLOCK 分块。 + // 存储顺序: + // n_block(N_BLOCK 列) → m_block(M_STEP 行) → n_step(N_STEP 列) → (M_STEP×N_STEP) 行主序 tile。 + // 因此可视为 5D: + // c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP], + // n_blocks = ceil(n / N_BLOCK),m_blocks = max_m / M_STEP, + // n_steps = N_BLOCK / N_STEP(尾块可能更小)。 + // get_submat(m_begin, n_begin) 返回连续的 (M_STEP×N_STEP) tile 起始地址。 + + static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; } + + BufferCFP32Impl(int max_m, int n, void* ptr) : max_m(max_m), n(n) { + assert(reinterpret_cast(ptr) % 64 == 0); + assert(max_m % M_STEP == 0); + assert(n % N_STEP == 0); + c = reinterpret_cast(ptr); + } + + void set_data(void* new_ptr) { c = reinterpret_cast(new_ptr); } + + void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) { + assert(m <= max_m); + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512* x0 = + (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); + __m512* x1 = + (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16); + avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin)); + } + } + } + } + + float* get_submat(int m, int n, int m_begin, int n_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int 
n_block_begin = n_begin / N_BLOCK * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + n_begin -= n_block_begin; + return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP; + } +}; + +template +struct BufferCFP32ReduceImpl { + float* c; + float* reduce_buf; + int max_m, n; + + static constexpr int M_STEP = K::M_STEP; + static constexpr int N_STEP = K::N_STEP; + static constexpr int N_BLOCK = K::N_BLOCK; + + static size_t required_size(int max_m, int n) { return sizeof(float) * (size_t)max_m * n * 2; } + + BufferCFP32ReduceImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) { + assert(max_m % M_STEP == 0); + assert(n % N_STEP == 0); + set_data(ptr); + } + + void set_data(void* ptr) { + assert(reinterpret_cast(ptr) % 64 == 0); + c = reinterpret_cast(ptr); + reduce_buf = c + (size_t)max_m * n; + } + + void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) { + assert(m <= max_m); + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_start; + int n_block_size = n_end - n_block_begin; + for (int m_begin = 0; m_begin < m; m_begin += M_STEP) { + for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) { + for (int i = 0; i < M_STEP && m_begin + i < m; i++) { + __m512* x0 = + (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP); + __m512* x1 = + (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16); + avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin)); + } + } + } + } + + float* get_submat(int m, int n, int m_begin, int n_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + n_begin -= n_block_begin; + return c + (size_t)m_block_size * n_block_begin + (size_t)m_begin * 
n_block_size + (size_t)n_begin * M_STEP; + } + + float* get_reduce_submat(int m, int n, int m_begin, int n_begin) { + int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP; + int n_block_begin = n_begin / N_BLOCK * N_BLOCK; + int n_block_size = std::min(N_BLOCK, n - n_block_begin); + n_begin -= n_block_begin; + return reduce_buf + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size + + (size_t)n_begin * M_STEP; + } +}; + +} // namespace amx + +#endif // AMX_RAW_BUFFERS_HPP diff --git a/kt-kernel/operators/amx/la/amx_raw_kernels.hpp b/kt-kernel/operators/amx/la/amx_raw_kernels.hpp new file mode 100644 index 0000000..9a38394 --- /dev/null +++ b/kt-kernel/operators/amx/la/amx_raw_kernels.hpp @@ -0,0 +1,464 @@ +#ifndef AMX_RAW_KERNELS_HPP +#define AMX_RAW_KERNELS_HPP + +#include +#include +#include +#include +#include + +#include "amx_config.hpp" +#include "amx_raw_buffers.hpp" +#include "amx_utils.hpp" +#include "llama.cpp/ggml-impl.h" + +namespace amx { + +struct GemmKernel224BF16 { + using dt = ggml_bf16_t; + using output_t = float; + static constexpr double ELEMENT_SIZE = 2; + static const int TILE_M = 16; + static const int TILE_K = 32; + static const int TILE_N = 16; + static const int VNNI_BLK = 2; + + static const int M_STEP = TILE_M * 2; + static const int N_STEP = TILE_N * 2; + static const int K_STEP = TILE_K; + + static inline const int N_BLOCK = 256; + static inline const int K_BLOCK = 1792; + static std::string name() { return "BF16"; } + + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + + static void config() { +#ifdef HAVE_AMX + enable_amx(); + TileConfig tile_config; + + // size is 16 x 32 + for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt)); + + // size is 16 x 32 + for (int i = 2; i < 4; i++) 
tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt)); + + // size is 16 x 16 + for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t)); + + tile_config.set_config(); +#endif + } + + static void load_a(dt* a, size_t lda) { +#ifdef HAVE_AMX + _tile_loadd(0, a, lda); + _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda); +#else + (void)a; + (void)lda; +#endif + } + + static void load_b(dt* b, size_t ldb) { +#ifdef HAVE_AMX + _tile_loadd(2, b, ldb); + _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb); +#else + (void)b; + (void)ldb; +#endif + } + + static void clean_c() { +#ifdef HAVE_AMX + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); +#endif + } + + static void load_c(output_t* c, size_t ldc) { +#ifdef HAVE_AMX + _tile_loadd(4, c, ldc); + _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); +#else + (void)c; + (void)ldc; +#endif + } + + static void store_c(output_t* c, size_t ldc) { +#ifdef HAVE_AMX + _tile_stored(4, c, ldc); + _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc); + _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc); + _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc); +#else + (void)c; + (void)ldc; +#endif + } + + static void run_tile() { +#ifdef HAVE_AMX + _tile_dpbf16ps(4, 0, 2); + _tile_dpbf16ps(5, 0, 3); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); +#endif + } + using BufferA = BufferABF16Impl; + using BufferB = BufferBBF16Impl; + using BufferC = BufferCFP32Impl; +}; + +// FP8 (e4m3) AMX kernel that mirrors the GemmKernel224BF16 interface. 
+struct GemmKernel224FP8 { + using fp8_t = uint8_t; + using output_t = float; + + static constexpr double ELEMENT_SIZE = 1.0; + static const int TILE_M = 16; + static const int TILE_K = 32; + static const int TILE_N = 16; + static const int VNNI_BLK = 2; + + static const int M_STEP = TILE_M * 2; + static const int N_STEP = TILE_N * 2; + static const int K_STEP = TILE_K; + + static inline const int BLOCK_SIZE = 128; // 128 x 128 block quantization + static inline const int N_BLOCK = 128; + static inline const int K_BLOCK = 7168; + + static std::string name() { return "FP8"; } + + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + + static void config() {} + + private: + alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = { + 0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, + 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, + 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, + 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, + }; + alignas(64) static constexpr uint8_t bf16_hi_1_val[64] = { + 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, + 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, + 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, + }; + alignas(64) static constexpr uint8_t bf16_lo_0_val[64] = { + 0x00, 0x00, 0x80, 0xc0, 0x00, 0x20, 0x40, 0x60, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 
0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + }; + alignas(64) static constexpr uint8_t bf16_lo_1_val[64] = { + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0, + }; + // _mm512_set1_epi8 is not constexpr; keep it as a static cached value + alignas(64) static const __m512i sign_mask_val; + static inline __m512i bf16_hi_0_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_0_val); } + static inline __m512i bf16_hi_1_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_1_val); } + static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); } + static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); } + static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); } + + public: + using BufferA = BufferABF16Impl; + using BufferB = BufferBFP8Impl; + using BufferC = BufferCFP32ReduceImpl; + + static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) { + // fp8->bf16 + __m512i b_hi = _mm512_permutex2var_epi8(bf16_hi_0_mask(), bfp8_512, bf16_hi_1_mask()); + __m512i b_lo = _mm512_permutex2var_epi8(bf16_lo_0_mask(), bfp8_512, bf16_lo_1_mask()); + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask(), bfp8_512), b_hi); + __m512i bbf16_0 = _mm512_unpacklo_epi8(b_lo, b_hi); + __m512i bbf16_1 = _mm512_unpackhi_epi8(b_lo, b_hi); + return {bbf16_0, bbf16_1}; + } + // Optimized AVX kernel: process entire k_group_size + // Load all data 
first, then convert all, then compute all + // This gives compiler more freedom to schedule instructions + static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba, + BufferB* bb, int k_group_size) { + const __m512i bf16_hi_0_val = bf16_hi_0_mask(); + const __m512i bf16_hi_1_val = bf16_hi_1_mask(); + const __m512i bf16_lo_0_val = bf16_lo_0_mask(); + const __m512i bf16_lo_1_val = bf16_lo_1_mask(); + const __m512i sign_mask_val = sign_mask(); + + __m512* c512 = (__m512*)c; + int m_block_end = std::min(m - m_begin, M_STEP); + + // Zero out accumulator at the start + for (int m_i = 0; m_i < m_block_end; m_i++) { + c512[m_i * 2] = _mm512_setzero_ps(); + c512[m_i * 2 + 1] = _mm512_setzero_ps(); + } + + // Process entire k_group_size + for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) { + ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin); + __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin); + + for (int m_i = 0; m_i < m_block_end; m_i++) { + // Process 2 k_i per iteration + for (int k_i = 0; k_i < 16; k_i += 2) { + // Load A vectors + __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]); + __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]); + + // Load B matrices + __m512i bfp8_0 = bfp8_512[k_i]; + __m512i bfp8_1 = bfp8_512[k_i + 1]; + + // Convert FP8 -> BF16 for all + __m512i b_hi_0 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_0, bf16_hi_1_val); + __m512i b_lo_0 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_0, bf16_lo_1_val); + b_hi_0 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_0), b_hi_0); + + __m512i b_hi_1 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_1, bf16_hi_1_val); + __m512i b_lo_1 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_1, bf16_lo_1_val); + b_hi_1 = 
_mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_1), b_hi_1); + + // Compute dpbf16 for all + __m512bh bbf16_0_0 = (__m512bh)_mm512_unpacklo_epi8(b_lo_0, b_hi_0); + __m512bh bbf16_1_0 = (__m512bh)_mm512_unpackhi_epi8(b_lo_0, b_hi_0); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_0); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_1_0); + + __m512bh bbf16_0_1 = (__m512bh)_mm512_unpacklo_epi8(b_lo_1, b_hi_1); + __m512bh bbf16_1_1 = (__m512bh)_mm512_unpackhi_epi8(b_lo_1, b_hi_1); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_0_1); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_1); + } + } + } + } + + // Optimized AVX kernel: process 4 k_i at once, convert B once and reuse for all m rows + // This version achieved ~493 GB/s - restoring as baseline for further optimization + static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba, + BufferB* bb, int k_group_size) { + const __m512i bf16_hi_0 = bf16_hi_0_mask(); + const __m512i bf16_hi_1 = bf16_hi_1_mask(); + const __m512i bf16_lo_0 = bf16_lo_0_mask(); + const __m512i bf16_lo_1 = bf16_lo_1_mask(); + const __m512i sign_mask_v = sign_mask(); + + __m512* c512 = (__m512*)c; + int m_block_end = std::min(m - m_begin, M_STEP); + + // Zero out accumulator + for (int m_i = 0; m_i < m_block_end; m_i++) { + c512[m_i * 2] = _mm512_setzero_ps(); + c512[m_i * 2 + 1] = _mm512_setzero_ps(); + } + + // Process entire k_group_size + for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) { + ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin); + __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin); + + // Process 4 k_i at once - convert B and reuse across all m rows + for (int k_i = 0; k_i < 16; k_i += 4) { + // Load 4 B vectors + __m512i bfp8_0 = bfp8_512[k_i]; + __m512i bfp8_1 = bfp8_512[k_i + 
1]; + __m512i bfp8_2 = bfp8_512[k_i + 2]; + __m512i bfp8_3 = bfp8_512[k_i + 3]; + + // Convert all 4 FP8 -> BF16 + __m512i b_hi, b_lo; + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1); + __m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1); + __m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1); + __m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1); + __m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + // Process m rows - unroll by 2 for better ILP + int m_i = 0; + for (; m_i + 1 < m_block_end; m_i += 2) { + // Load A values for 2 rows + __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]); + __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]); + __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]); + __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]); + __m512bh ma0_1 = 
(__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]); + __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]); + __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]); + __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]); + + // Process row 0, then row 1 - sequential to avoid dependencies + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi); + + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi); + } + // Handle remaining row + for (; m_i < m_block_end; m_i++) { + __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]); + __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]); + __m512bh 
ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]); + __m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]); + + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi); + } + } + } + } + + static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_block_begin, float* c, float* reduce_c, + BufferA* ba, BufferB* bb, int k, int k_group_size) { + using K = GemmKernel224FP8; + int to = std::min(m - m_begin, K::M_STEP); + + for (int i = 0; i < to; i++) { + // Get scale for this k_group + __m512 bs = _mm512_set1_ps(*bb->get_scale(n, n_begin, k, k_block_begin)); + __m512 now = _mm512_load_ps(reduce_c + i * K::N_STEP); + __m512 result = _mm512_mul_ps(now, bs); + __m512 existing = _mm512_load_ps(c + i * K::N_STEP); + result = _mm512_add_ps(result, existing); + _mm512_store_ps(c + i * K::N_STEP, result); + + now = _mm512_load_ps(reduce_c + i * K::N_STEP + K::TILE_N); + result = _mm512_mul_ps(now, bs); + existing = _mm512_load_ps(c + i * K::N_STEP + K::TILE_N); + result = _mm512_add_ps(result, existing); + _mm512_store_ps(c + i * K::N_STEP + K::TILE_N, result); + } + } +}; + +// all step = 32 +template +void float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb, + typename K::BufferC* bc, int ith, int nth) { + assert(n % K::N_STEP == 0); + assert(k % k_group_size == 0); + assert(k_group_size % K::K_STEP == 0); + + auto [n_start, 
n_end] = K::split_range_n(n, ith, nth); + + // Process by k_groups + for (int k_group_begin = 0; k_group_begin < k; k_group_begin += k_group_size) { + for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { + for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { + float* c = bc->get_submat(m, n, m_begin, n_begin); + float* reduce_c = bc->get_reduce_submat(m, n, m_begin, n_begin); + + if (k_group_begin == 0) { + for (int i = 0; i < K::M_STEP && m_begin + i < m; i++) { + for (int j = 0; j < K::N_STEP; j++) { + c[i * K::N_STEP + j] = 0.0f; + } + } + } + + // avx_kernel_4 now processes entire k_group_size internally (like INT8's avx_kernel) + if constexpr (amx_or_avx && AMX_AVAILABLE) { + for (int k_begin = k_group_begin; k_begin < std::min(k, k_group_begin + k_group_size); k_begin += K::K_STEP) { + K::amx_kernel(m, n, k, m_begin, n_begin, k_begin, reduce_c, ba, bb, k_group_size); + } + } else { + // Single call processes entire k_group + K::avx_kernel_4(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size); + } + K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, reduce_c, ba, bb, k, k_group_size); + } + } + } +} + +// inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr ba, +// std::shared_ptr bb, +// std::shared_ptr bc, int ith, int nth) { +// float_mat_mul_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +// } + +// inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr ba, +// std::shared_ptr bb, +// std::shared_ptr bc, int ith, int nth) { +// float_mat_mul_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +// } + +inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr ba, + std::shared_ptr bb, std::shared_ptr bc, + int ith, int nth) { + float_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + +inline void mat_mul_kgroup(int m, int n, int k, int 
k_group_size, std::shared_ptr ba, + std::shared_ptr bb, std::shared_ptr bc, + int ith, int nth) { + float_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + +} // namespace amx + +#endif // AMX_RAW_KERNELS_HPP diff --git a/kt-kernel/operators/amx/moe.hpp b/kt-kernel/operators/amx/moe.hpp index cbb168e..168b04b 100644 --- a/kt-kernel/operators/amx/moe.hpp +++ b/kt-kernel/operators/amx/moe.hpp @@ -11,30 +11,27 @@ #define CPUINFER_OPERATOR_AMX_MOE_H // #define CHECK - -#include -#include -#include // #define FORWARD_TIME_PROFILE // #define FORWARD_TIME_REPORT -#include -#include -#include -#include -#include -#include - -#include "../../cpu_backend/shared_mem_buffer.h" -#include "../../cpu_backend/worker_pool.h" -#include "../moe-tp.hpp" -#include "la/amx.hpp" -#include "llama.cpp/ggml.h" +#include "moe_base.hpp" template -class AMX_MOE_TP { +class AMX_MOE_TP : public AMX_MOE_BASE> { private: - int tp_part_idx; + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::tp_part_idx; + using Base::gate_bb_; + using Base::up_bb_; + using Base::down_bb_; + using Base::gate_up_ba_; + using Base::gate_bc_; + using Base::up_bc_; + using Base::down_ba_; + using Base::down_bc_; + using Base::m_local_num_; + std::filesystem::path prefix; void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if @@ -44,27 +41,6 @@ class AMX_MOE_TP { void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if // quantized)] - ggml_bf16_t* m_local_input_; // [num_experts_per_tok * max_len * hidden_size] - ggml_bf16_t* m_local_gate_output_; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_up_output_; // [num_experts_per_tok * max_len * intermediate_size] - ggml_bf16_t* m_local_down_output_; // [num_experts_per_tok * max_len * hidden_size] - - std::vector> m_local_pos_; // [max_len, num_experts_per_tok] - std::vector m_local_num_; // [expert_num] - std::vector m_expert_id_map_; // [expert_num] - 
std::vector m_local_input_ptr_; // [expert_num] - std::vector m_local_gate_output_ptr_; // [expert_num] - std::vector m_local_up_output_ptr_; // [expert_num] - std::vector m_local_down_output_ptr_; // [expert_num] - - std::vector> gate_up_ba_; - std::vector> gate_bb_; - std::vector> gate_bc_; - std::vector> up_bb_; - std::vector> up_bc_; - std::vector> down_ba_; - std::vector> down_bb_; - std::vector> down_bc_; #ifdef CHECK char verify_bb[100000000]; char check_bb[100000000]; @@ -161,21 +137,15 @@ class AMX_MOE_TP { #endif public: - using input_t = ggml_bf16_t; - using output_t = float; - GeneralMOEConfig config_; - static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; + AMX_MOE_TP() = default; - AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx) { + AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) { printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); - auto& load = config.load; - auto& save = config.save; - if (load && config.path == "") { - load = false; - } + auto& load = config_.load; + auto& save = config_.save; - prefix = config.path; - prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); + prefix = config_.path; + prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx)); if (save) { std::cout << "Creating " << prefix << std::endl; std::filesystem::create_directories(prefix); @@ -188,78 +158,65 @@ class AMX_MOE_TP { } } - this->tp_part_idx = tp_part_idx; - config_ = config; gate_proj_ = config_.gate_proj; up_proj_ = config_.up_proj; down_proj_ = config_.down_proj; - - MemoryRequest mem_requests; - mem_requests.append_pointer( - &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size); - mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * 
config_.intermediate_size); - mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.intermediate_size); - mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * - config_.max_len * config_.hidden_size); - - m_local_pos_.resize(config_.max_len); - for (int i = 0; i < config_.max_len; i++) { - m_local_pos_[i].resize(config_.num_experts_per_tok); - } - m_expert_id_map_.resize(config_.expert_num); - m_local_num_.resize(config_.expert_num); - m_local_input_ptr_.resize(config_.expert_num); - m_local_gate_output_ptr_.resize(config_.expert_num); - m_local_up_output_ptr_.resize(config_.expert_num); - m_local_down_output_ptr_.resize(config_.expert_num); - - // printf("tp part %d alloc layer %d, %f GB, on numa %d\n", tp_part_idx, config_.layer_idx, - // 1e-9 * config_.expert_num * - // (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 + - // T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)), - // numa_node_of_cpu(sched_getcpu())); - - for (size_t i = 0; i < config_.expert_num; i++) { - gate_up_ba_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); - gate_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - up_bc_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - down_ba_.push_back(std::make_shared(config_.max_len, config_.intermediate_size, nullptr)); - down_bc_.push_back(std::make_shared(config_.max_len, config_.hidden_size, nullptr)); - - void* gate_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); - gate_bb_.push_back( - std::make_shared(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); - - void* up_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size)); - up_bb_.push_back( - 
std::make_shared(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); - - void* down_bb_ptr = - std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)); - down_bb_.push_back( - std::make_shared(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); - } - for (int i = 0; i < config_.expert_num; i++) { - mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); }, - T::BufferA::required_size(config_.max_len, config_.hidden_size)); - mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); }, - T::BufferC::required_size(config_.max_len, config_.intermediate_size)); - mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); }, - T::BufferC::required_size(config_.max_len, config_.intermediate_size)); - mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); }, - T::BufferA::required_size(config_.max_len, config_.intermediate_size)); - mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); }, - T::BufferC::required_size(config_.max_len, config_.hidden_size)); - } - shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests); } - ~AMX_MOE_TP() { - // shared_mem_buffer_numa.dealloc(this); + ~AMX_MOE_TP() = default; + + // ============================================================================ + // CRTP buffer creation - no group_size + // ============================================================================ + + size_t buffer_a_required_size_impl(size_t m, size_t k) const { + return T::BufferA::required_size(m, k); + } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + return T::BufferB::required_size(n, k); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { + return T::BufferC::required_size(m, n); + } + + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, 
data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + return std::make_shared(n, k, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + // ============================================================================ + // CRTP virtual points - GEMM dispatch + // ============================================================================ + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth); + } else { + amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth); + } + } + + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + int m = m_local_num_[expert_idx]; + auto& ba = down_ba_[expert_idx]; + auto& bb = down_bb_[expert_idx]; + auto& bc = down_bc_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth); + } else { + amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth); + } } void load_weights() { auto pool = config_.pool->get_subpool(tp_part_idx); @@ -401,434 +358,21 @@ class AMX_MOE_TP { } } - void warm_up() { - int qlen = config_.max_len; - std::vector input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); - std::vector expert_ids(qlen * config_.num_experts_per_tok); - std::vector weights(qlen * config_.num_experts_per_tok); - for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) { - expert_ids[i] 
= i % config_.expert_num; - weights[i] = 0.01; - } - forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data()); - } - - void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - if (qlen > 1) { - forward_prefill(qlen, k, expert_ids, weights, input, output); - } else { - forward_decode(k, expert_ids, weights, input, output); - } - } - -#define DIRECT_OR_POOL_BY_QLEN(var, fn) \ - do { \ - if (qlen < 10) { \ - for (int i = 0; i < (var); i++) { \ - (fn)(i); \ - } \ - } else { \ - pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \ - } \ - } while (0) - -#define MATMUL_OR_VECMUL_BY_QLEN(...) \ - do { \ - if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { \ - amx::mat_mul(__VA_ARGS__); \ - } else { \ - amx::vec_mul(__VA_ARGS__); \ - } \ - } while (0) - - void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, - void* output) { - auto pool = config_.pool->get_subpool(tp_part_idx); -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < config_.expert_num; i++) { - m_local_num_[i] = 0; - } - for (int i = 0; i < qlen; i++) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; - } - } - - for (int i = 0; i < config_.expert_num; i++) { - if (m_local_num_[i] > 0) { -#ifdef FORWARD_TIME_PROFILE - max_local_num = std::max(max_local_num, m_local_num_[i]); -#endif - m_expert_id_map_[activated_expert] = i; - 
activated_expert++; - } - } - - // activated_expert 已经统计完成 - - size_t offset = 0; - for (int i = 0; i < config_.expert_num; i++) { - m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size; - m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size; - m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; - m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; - offset += m_local_num_[i]; - } -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - prepare_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) { - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, - (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); - } - }); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - cpy_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); - }); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / 
nth]; - - int ith = task_id % nth; - if (do_up) { - MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, - gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth); - up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size, - gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - auto up_gate_fn = [this, nth](int task_id) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < m_local_num_[expert_idx]; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - }; - DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - 
last = now_time; - } -#endif - - pool->do_work_stealing_job( - activated_expert, nullptr, - [this](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, - down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); - down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - qlen, nullptr, - [this, nth, output, k, expert_ids, weights](int i) { - for (int e = 0; e < config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - __m512 weight = _mm512_set1_ps(weights[i * k + j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + - m_local_pos_[i][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = 
(__m512*)((float*)output + i * config_.hidden_size + e); - f32out[0] = x0; - f32out[1] = x1; - } - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: " - "%d, qlen: %d\n", - tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time, - down_time, weight_time, forward_total_time, max_local_num, qlen); -#endif - } - - void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { - int qlen = 1; - auto pool = config_.pool->get_subpool(tp_part_idx); -#ifdef FORWARD_TIME_PROFILE - auto start_time = std::chrono::high_resolution_clock::now(); - auto last = start_time; - // 用于保存各阶段耗时(单位:微秒) - long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; - long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; - int max_local_num = 0; // 记录最大的 local num -#endif - - int activated_expert = 0; - for (int i = 0; i < k; i++) { - if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) { - continue; - } - m_expert_id_map_[activated_expert] = expert_ids[i]; - activated_expert++; - } - - size_t offset = 0; - for (int i = 0; i < activated_expert; i++) { - auto expert_idx = m_expert_id_map_[i]; - m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size; - m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * 
config_.intermediate_size; - m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size; - offset += qlen; - } - - gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_input_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - int nth = T::recommended_nth(config_.intermediate_size); - pool->do_work_stealing_job( - nth * activated_expert * 2, [](int _) { T::config(); }, - [this, nth, qlen](int task_id2) { - int task_id = task_id2 / 2; - bool do_up = task_id2 % 2; - int expert_idx = m_expert_id_map_[task_id / nth]; - - int ith = task_id % nth; - if (do_up) { - amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], up_bb_[expert_idx], - up_bc_[expert_idx], ith, nth); - up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth); - } else { - amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], gate_bb_[expert_idx], - gate_bc_[expert_idx], ith, nth); - gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth); - } - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - up_gate_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - for (int task_id = 0; task_id < nth * activated_expert; task_id++) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); - for (int i = 0; i < qlen; i++) { - ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; - ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; - for (int j = n_start; j < n_end; j += 32) { - __m512 gate_val0, gate_val1, up_val0, up_val1; - 
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); - avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); - avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); - } - } - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - act_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - pool->do_work_stealing_job( - activated_expert, nullptr, - [this, qlen](int task_id) { - int expert_idx = m_expert_id_map_[task_id]; - down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - q_down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - - nth = T::recommended_nth(config_.hidden_size); - pool->do_work_stealing_job( - nth * activated_expert, [](int _) { T::config(); }, - [this, nth, qlen](int task_id) { - int expert_idx = m_expert_id_map_[task_id / nth]; - int ith = task_id % nth; - amx::vec_mul(qlen, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx], - down_bc_[expert_idx], ith, nth); - down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth); - }, - nullptr); -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - down_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } -#endif - for (int i = 0; i < qlen; i++) { - for (int e = 0; e < config_.hidden_size; e += 32) { - __m512 x0 = _mm512_setzero_ps(); - __m512 x1 = _mm512_setzero_ps(); - for (int j = 0; j < k; j++) { - if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { - continue; - } - __m512 
weight = _mm512_set1_ps(weights[i * k + j]); - __m512 down_output0, down_output1; - avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + - m_local_pos_[i][j] * config_.hidden_size + e), - &down_output0, &down_output1); - x0 = _mm512_fmadd_ps(down_output0, weight, x0); - x1 = _mm512_fmadd_ps(down_output1, weight, x1); - } - auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e); - f32out[0] = x0; - f32out[1] = x1; - } - } - -#ifdef FORWARD_TIME_PROFILE - { - auto now_time = std::chrono::high_resolution_clock::now(); - weight_time = std::chrono::duration_cast(now_time - last).count(); - last = now_time; - } - auto end_time = std::chrono::high_resolution_clock::now(); - auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); - // 在函数末尾一次性打印所有阶段的耗时,并附带 max_local_num 和 qlen - printf( - "Profiling Results (numa[%d]) decode: activated_expert: %d, q_input: %ld us, " - "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n", - tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time, - forward_total_time); -#endif - } + // forward, forward_prefill, forward_decode, warm_up are inherited from Base }; +// ============================================================================ +// TP_MOE specialization for AMX_MOE_TP +// Inherits from TP_MOE> to reuse merge_results implementation +// ============================================================================ + template -class TP_MOE> : public TP_MOE_Common> { +class TP_MOE> : public TP_MOE>> { public: - using TP_MOE_Common>::TP_MOE_Common; - void load_weights() { + using Base = TP_MOE>>; + using Base::Base; + + void load_weights() override { auto& config = this->config; auto& tps = this->tps; auto& tp_count = this->tp_count; @@ -836,7 +380,6 @@ class TP_MOE> : public TP_MOE_Common> { const uint64_t* physical_to_logical_map = (const 
uint64_t*)config.physical_to_logical_map; if (config.gate_projs.empty() == false) { printf("TP Load from loader\n"); - // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); }); DO_TPS_LOAD_WEIGHTS(pool); this->weights_loaded = true; } else if (config.gate_proj != nullptr) { @@ -872,7 +415,6 @@ class TP_MOE> : public TP_MOE_Common> { } } - // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); }); DO_TPS_LOAD_WEIGHTS(pool); for (auto i = 0; i < tp_count; i++) { @@ -885,7 +427,6 @@ class TP_MOE> : public TP_MOE_Common> { this->weights_loaded = true; } else if (config.path != "") { printf("TP Load from file\n"); - // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); }); DO_TPS_LOAD_WEIGHTS(pool); this->weights_loaded = true; } else { @@ -893,37 +434,7 @@ class TP_MOE> : public TP_MOE_Common> { } } - void merge_results(int qlen, void* output, bool incremental) { - auto pool = this->config.pool; - auto merge_fn = [this, output, incremental](int token_nth) { - auto& local_output_numa = this->local_output_numa; - auto& tp_configs = this->tp_configs; - auto& tp_count = this->tp_count; - auto& config = this->config; - float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size; - if (incremental) { - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0, x1; - avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1); - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0); - *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1); - } - } - for (int i = 1; i < tp_count; i++) { - float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size; - for (int e = 0; e < tp_configs[i].hidden_size; e += 16) { - *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), 
*((__m512*)(merge_from + e))); - } - } - for (int e = 0; e < config.hidden_size; e += 32) { - __m512 x0 = *(__m512*)(merge_to + e); - __m512 x1 = *(__m512*)(merge_to + e + 16); - avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e)); - } - }; - DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn); - } - void merge_results(int qlen, void* output) { merge_results(qlen, output, false); } + // merge_results is inherited from TP_MOE>> }; #endif diff --git a/kt-kernel/operators/amx/moe_base.hpp b/kt-kernel/operators/amx/moe_base.hpp new file mode 100644 index 0000000..e1bb093 --- /dev/null +++ b/kt-kernel/operators/amx/moe_base.hpp @@ -0,0 +1,763 @@ +/** + * @Description : Common AMX MoE base class extracted from K2 implementation. + * @Author : oql, Codex and Claude + * @Date : 2025-12-09 + * @Version : 0.1.0 + * @LastEditors : oql, Codex and Claude + * @LastEditTime : 2025-12-09 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + **/ +#ifndef CPUINFER_OPERATOR_AMX_MOE_BASE_H +#define CPUINFER_OPERATOR_AMX_MOE_BASE_H + +// #define FORWARD_TIME_PROFILE + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../cpu_backend/shared_mem_buffer.h" +#include "../../cpu_backend/worker_pool.h" +#include "../common.hpp" +#include "../moe-tp.hpp" +#include "la/amx.hpp" +#include "llama.cpp/ggml.h" + +template +class AMX_MOE_BASE { + public: + int tp_part_idx = 0; + + ggml_bf16_t* m_local_input_ = nullptr; + ggml_bf16_t* m_local_gate_output_ = nullptr; + ggml_bf16_t* m_local_up_output_ = nullptr; + ggml_bf16_t* m_local_down_output_ = nullptr; + + std::vector> m_local_pos_; + std::vector m_local_num_; + std::vector m_expert_id_map_; + std::vector m_local_input_ptr_; + std::vector m_local_gate_output_ptr_; + std::vector m_local_up_output_ptr_; + std::vector m_local_down_output_ptr_; + + std::vector> gate_up_ba_; + std::vector> gate_bb_; + 
std::vector> gate_bc_; + std::vector> up_bb_; + std::vector> up_bc_; + std::vector> down_ba_; + std::vector> down_bb_; + std::vector> down_bc_; + + size_t pool_count_ = 0; + size_t gate_up_ba_pool_bytes_ = 0; + size_t gate_bc_pool_bytes_ = 0; + size_t up_bc_pool_bytes_ = 0; + size_t down_ba_pool_bytes_ = 0; + size_t down_bc_pool_bytes_ = 0; + void* gate_up_ba_pool_ = nullptr; + void* gate_bc_pool_ = nullptr; + void* up_bc_pool_ = nullptr; + void* down_ba_pool_ = nullptr; + void* down_bc_pool_ = nullptr; + + GeneralMOEConfig config_; + using input_t = ggml_bf16_t; + using output_t = float; + static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE; + + AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); } + + void init() { + if (config_.load && config_.path == "") { + config_.load = false; + } + + MemoryRequest mem_requests; + mem_requests.append_pointer( + &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size); + mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.intermediate_size); + mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.intermediate_size); + mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * + config_.max_len * config_.hidden_size); + + m_local_pos_.resize(config_.max_len); + for (int i = 0; i < config_.max_len; i++) { + m_local_pos_[i].resize(config_.num_experts_per_tok); + } + m_expert_id_map_.resize(config_.expert_num); + m_local_num_.resize(config_.expert_num); + m_local_input_ptr_.resize(config_.expert_num); + m_local_gate_output_ptr_.resize(config_.expert_num); + m_local_up_output_ptr_.resize(config_.expert_num); + m_local_down_output_ptr_.resize(config_.expert_num); + + for (size_t i = 0; i < config_.expert_num; i++) { + 
gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr)); + gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr)); + up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr)); + down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr)); + down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr)); + + void* gate_bb_ptr = + std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size)); + gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr)); + + void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size)); + up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr)); + + void* down_bb_ptr = + std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size)); + down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr)); + } + // TODO: need update to all *.hpp + // (config_.expert_num * T::M_STEP) in pool_count_ is to ensure padding for each experts. 
+ pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP; + + gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64; + gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64; + up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64; + down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64; + down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64; + + mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_); + mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_); + mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_); + mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_); + mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_); + + shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests); + } + + ~AMX_MOE_BASE() = default; + + void warm_up() { + int qlen = config_.max_len; + std::vector input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); + std::vector output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size); + std::vector expert_ids(qlen * config_.num_experts_per_tok); + std::vector weights(qlen * config_.num_experts_per_tok); + for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) { + expert_ids[i] = i % config_.expert_num; + weights[i] = 0.01; + } + forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data()); + } + + void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { + if (qlen > 1) { + forward_prefill(qlen, k, expert_ids, weights, input, output); + } else { + forward_decode(k, expert_ids, weights, input, output); + } + } + + template + void load_weights(Args&&... 
args) { + derived()->load_weights(std::forward(args)...); + } + + template + void write_weights_to_buffer(Args&&... args) const { + derived_const()->write_weights_to_buffer(std::forward(args)...); + } + + void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, + void* output) { + auto pool = config_.pool->get_subpool(tp_part_idx); +#ifdef FORWARD_TIME_PROFILE + auto start_time = std::chrono::high_resolution_clock::now(); + auto last = start_time; + long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0; + long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; + int max_local_num = 0; +#endif + + int activated_expert = 0; + std::fill(m_local_num_.begin(), m_local_num_.end(), 0); + for (int i = 0; i < qlen; i++) { + for (int j = 0; j < k; j++) { + if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { + continue; + } + m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++; + } + } + + for (int i = 0; i < config_.expert_num; i++) { + if (m_local_num_[i] > 0) { +#ifdef FORWARD_TIME_PROFILE + max_local_num = std::max(max_local_num, m_local_num_[i]); +#endif + m_expert_id_map_[activated_expert] = i; + activated_expert++; + } + } + + size_t offset = 0; + void* gate_up_ba_pool_ptr = gate_up_ba_pool_; + void* gate_bc_pool_ptr = gate_bc_pool_; + void* up_bc_pool_ptr = up_bc_pool_; + void* down_ba_pool_ptr = down_ba_pool_; + void* down_bc_pool_ptr = down_bc_pool_; + constexpr size_t M_STEP = T::M_STEP; + auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); }; + size_t used_pool_m = 0; + size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0, + used_pool_bytes_bc_down = 0; + + for (int i = 0; i < config_.expert_num; i++) { + m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size; + m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * 
config_.intermediate_size; + m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size; + m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size; + offset += m_local_num_[i]; + + if (m_local_num_[i] == 0) { + continue; + } + + size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP; + gate_up_ba_[i]->max_m = max_m; + gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr); + size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size)); + gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size); + + gate_bc_[i]->max_m = max_m; + gate_bc_[i]->set_data(gate_bc_pool_ptr); + size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size)); + gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size); + + up_bc_[i]->max_m = max_m; + up_bc_[i]->set_data(up_bc_pool_ptr); + size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size)); + up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size); + + down_ba_[i]->max_m = max_m; + down_ba_[i]->set_data(down_ba_pool_ptr); + size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size)); + down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size); + + down_bc_[i]->max_m = max_m; + down_bc_[i]->set_data(down_bc_pool_ptr); + size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size)); + down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size); + + used_pool_m += max_m; + used_pool_bytes_a += ba_size; + used_pool_bytes_bc_gate += bc_gate_size; + used_pool_bytes_bc_up += bc_up_size; + used_pool_bytes_ba_down += ba_down_size; + used_pool_bytes_bc_down += bc_down_size; + } + + assert(used_pool_m <= pool_count_); + assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_); + assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_); + assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_); + assert(used_pool_bytes_ba_down <= 
down_ba_pool_bytes_); + assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + prepare_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + auto direct_or_pool = [&](int count, auto&& fn) { + if (qlen < 10) { + for (int i = 0; i < count; i++) { + fn(i); + } + } else { + pool->do_work_stealing_job(count, nullptr, fn, nullptr); + } + }; + + direct_or_pool(qlen, [&](int i) { + for (int j = 0; j < k; j++) { + if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { + continue; + } + memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size, + (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); + } + }); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + cpy_input_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + direct_or_pool(activated_expert, [this](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); + }); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_input_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * activated_expert * 2, [](int _) { T::config(); }, + [this, nth, qlen](int task_id2) { + int task_id = task_id2 / 2; + bool do_up = task_id2 % 2; + int expert_idx = m_expert_id_map_[task_id / nth]; + + int ith = task_id % nth; + derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen); + if (do_up) { + up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth); + } else { + 
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth); + } + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + up_gate_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + apply_activation(activated_expert, nth, qlen); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + act_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + pool->do_work_stealing_job( + activated_expert, nullptr, + [this](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1); + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * activated_expert, [](int _) { T::config(); }, + [this, nth, qlen](int task_id) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; + derived()->do_down_gemm(expert_idx, ith, nth, qlen); + down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth); + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + pool->do_work_stealing_job( + qlen, nullptr, + [this, output, k, expert_ids, weights](int i) { + for (int e = 0; e < config_.hidden_size; e += 32) { + __m512 x0 = _mm512_setzero_ps(); + __m512 x1 = _mm512_setzero_ps(); + for (int j = 0; j < k; j++) { + if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) { + 
continue; + } + __m512 weight = _mm512_set1_ps(weights[i * k + j]); + __m512 down_output0, down_output1; + avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] + + m_local_pos_[i][j] * config_.hidden_size + e), + &down_output0, &down_output1); + x0 = _mm512_fmadd_ps(down_output0, weight, x0); + x1 = _mm512_fmadd_ps(down_output1, weight, x1); + } + auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e); + f32out[0] = x0; + f32out[1] = x1; + } + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + weight_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } + auto end_time = std::chrono::high_resolution_clock::now(); + auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); + printf( + "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, " + "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: " + "%d, qlen: %d\n", + tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time, + down_time, weight_time, forward_total_time, max_local_num, qlen); +#endif + } + + void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) { + int qlen = 1; + auto pool = config_.pool->get_subpool(tp_part_idx); +#ifdef FORWARD_TIME_PROFILE + auto start_time = std::chrono::high_resolution_clock::now(); + auto last = start_time; + long q_input_time = 0, up_gate_time = 0, act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0; +#endif + + int activated_expert = 0; + std::fill(m_local_num_.begin(), m_local_num_.end(), 0); + for (int i = 0; i < k; i++) { + if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) { + continue; + } + m_expert_id_map_[activated_expert] = expert_ids[i]; + 
m_local_pos_[0][i] = 0; + m_local_num_[expert_ids[i]] = qlen; + activated_expert++; + } + + size_t offset = 0; + for (int i = 0; i < activated_expert; i++) { + auto expert_idx = m_expert_id_map_[i]; + m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size; + m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size; + m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size; + offset += qlen; + } + + void* gate_bc_pool_ptr = gate_bc_pool_; + void* up_bc_pool_ptr = up_bc_pool_; + void* down_ba_pool_ptr = down_ba_pool_; + void* down_bc_pool_ptr = down_bc_pool_; + constexpr size_t M_STEP = T::M_STEP; + auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); }; + size_t used_pool_m = 0; + size_t used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0, + used_pool_bytes_bc_down = 0; + for (int i = 0; i < activated_expert; i++) { + auto expert_idx = m_expert_id_map_[i]; + size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; + + gate_bc_[expert_idx]->max_m = max_m; + gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr); + size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size)); + gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size); + + up_bc_[expert_idx]->max_m = max_m; + up_bc_[expert_idx]->set_data(up_bc_pool_ptr); + size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size)); + up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size); + + down_ba_[expert_idx]->max_m = max_m; + down_ba_[expert_idx]->set_data(down_ba_pool_ptr); + size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size)); + down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size); + + down_bc_[expert_idx]->max_m = max_m; + down_bc_[expert_idx]->set_data(down_bc_pool_ptr); + size_t bc_down_size = align64(buffer_c_required_size(max_m, 
config_.hidden_size)); + down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size); + + used_pool_m += max_m; + used_pool_bytes_bc_gate += bc_gate_size; + used_pool_bytes_bc_up += bc_up_size; + used_pool_bytes_ba_down += ba_down_size; + used_pool_bytes_bc_down += bc_down_size; + } + assert(used_pool_m <= pool_count_); + assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_); + assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_); + assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_); + assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_); + + void* gate_up_ba_pool_ptr = gate_up_ba_pool_; + for (int i = 0; i < activated_expert; i++) { + auto expert_idx = m_expert_id_map_[i]; + size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP; + gate_up_ba_[expert_idx]->max_m = max_m; + gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr); + size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size)); + gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size); + gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1); + } + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_input_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * activated_expert * 2, [](int _) { T::config(); }, + [this, nth, qlen](int task_id2) { + int task_id = task_id2 / 2; + bool do_up = task_id2 % 2; + int expert_idx = m_expert_id_map_[task_id / nth]; + + int ith = task_id % nth; + derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen); + if (do_up) { + up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth); + } else { + gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth); + } + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + up_gate_time = 
std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + apply_activation(activated_expert, nth, qlen); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + act_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + pool->do_work_stealing_job( + activated_expert, nullptr, + [this, qlen](int task_id) { + int expert_idx = m_expert_id_map_[task_id]; + down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1); + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + q_down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * activated_expert, [](int _) { T::config(); }, + [this, nth, qlen](int task_id) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; + derived()->do_down_gemm(expert_idx, ith, nth, qlen); + down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth); + }, + nullptr); + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + down_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } +#endif + + for (int e = 0; e < config_.hidden_size; e += 32) { + __m512 x0 = _mm512_setzero_ps(); + __m512 x1 = _mm512_setzero_ps(); + for (int j = 0; j < k; j++) { + if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) { + continue; + } + __m512 weight = _mm512_set1_ps(weights[j]); + __m512 down_output0, down_output1; + avx512_32xbf16_to_32xfp32( + (__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e), + &down_output0, &down_output1); + x0 = _mm512_fmadd_ps(down_output0, weight, x0); + x1 = _mm512_fmadd_ps(down_output1, weight, x1); + } + auto f32out = 
(__m512*)((float*)output + e); + f32out[0] = x0; + f32out[1] = x1; + } + +#ifdef FORWARD_TIME_PROFILE + { + auto now_time = std::chrono::high_resolution_clock::now(); + weight_time = std::chrono::duration_cast(now_time - last).count(); + last = now_time; + } + auto end_time = std::chrono::high_resolution_clock::now(); + auto forward_total_time = std::chrono::duration_cast(end_time - start_time).count(); + printf( + "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, " + "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n", + tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time, + forward_total_time); +#endif + } + + protected: + Derived* derived() { return static_cast(this); } + const Derived* derived_const() const { return static_cast(this); } + + // ============================================================================ + // Virtual points for buffer creation and size calculation + // Default implementations use group_size (for KGroup quantization like K2) + // Derived classes (like moe.hpp) can override to not use group_size + // ============================================================================ + + size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); } + size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); } + size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); } + + std::shared_ptr make_buffer_a(size_t m, size_t k, void* data) const { + return derived_const()->make_buffer_a_impl(m, k, data); + } + std::shared_ptr make_buffer_b(size_t n, size_t k, void* data) const { + return derived_const()->make_buffer_b_impl(n, k, data); + } + std::shared_ptr make_buffer_c(size_t m, size_t n, void* data) const { + return derived_const()->make_buffer_c_impl(m, n, 
data); + } + + void apply_activation(int activated_expert, int nth, int qlen) { + auto pool = config_.pool->get_subpool(tp_part_idx); + auto fn = [this, nth](int task_id) { + int expert_idx = m_expert_id_map_[task_id / nth]; + int ith = task_id % nth; + auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth); + for (int i = 0; i < m_local_num_[expert_idx]; i++) { + ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size]; + ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size]; + for (int j = n_start; j < n_end; j += 32) { + __m512 gate_val0, gate_val1, up_val0, up_val1; + avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); + avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); + __m512 result0 = amx::act_fn(gate_val0, up_val0); + __m512 result1 = amx::act_fn(gate_val1, up_val1); + avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); + } + } + }; + + if (activated_expert == 0) { + return; + } + + if (qlen < 10) { + for (int task_id = 0; task_id < nth * activated_expert; task_id++) { + fn(task_id); + } + } else { + pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr); + } + } +}; + +// ============================================================================ +// TP_MOE specialization for AMX_MOE_BASE derived classes +// ============================================================================ + +template +class TP_MOE> : public TP_MOE_Common> { + public: + using TP_MOE_Common>::TP_MOE_Common; + + // Default load_weights implementation - can be overridden by derived TP_MOE classes + void load_weights() override { throw std::runtime_error("Not Implemented"); } + + void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num, + const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const 
std::vector& w2_scale_ptrs) { + throw std::runtime_error("Not Implemented"); + } + + void merge_results(int qlen, void* output, bool incremental) override { + auto& config = this->config; + auto& tp_count = this->tp_count; + auto& local_output_numa = this->local_output_numa; + auto& tp_configs = this->tp_configs; + + auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) { + float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size; + if (incremental) { + for (int e = 0; e < config.hidden_size; e += 32) { + __m512 x0, x1; + avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1); + *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0); + *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1); + } + } + for (int i = 1; i < tp_count; i++) { + float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size; + for (int e = 0; e < tp_configs[i].hidden_size; e += 16) { + *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e))); + } + } + for (int e = 0; e < config.hidden_size; e += 32) { + __m512 x0 = *(__m512*)(merge_to + e); + __m512 x1 = *(__m512*)(merge_to + e + 16); + avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e)); + } + }; + + auto pool = config.pool; + + auto direct_or_pool = [&](int count, auto&& fn) { + if (qlen < 10) { + for (int i = 0; i < count; i++) { + fn(i); + } + } else { + pool->do_work_stealing_job(count, nullptr, fn, nullptr); + } + }; + + direct_or_pool(qlen, merge_fn); + } + + void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); } +}; + +#endif // CPUINFER_OPERATOR_AMX_MOE_BASE_H diff --git a/kt-kernel/pyproject.toml b/kt-kernel/pyproject.toml index 3c7b537..4c9e55e 100644 --- a/kt-kernel/pyproject.toml +++ 
b/kt-kernel/pyproject.toml @@ -27,6 +27,12 @@ dependencies = [ "numpy>=1.24.0", "triton>=2.0.0", "gguf>=0.17.0", + # CLI dependencies + "typer[all]>=0.9.0", + "rich>=13.0.0", + "pyyaml>=6.0", + "httpx>=0.25.0", + "packaging>=23.0", # Development dependencies "black>=25.9.0", ] @@ -37,19 +43,35 @@ test = [ "psutil>=5.9.0", ] +[project.scripts] +kt = "kt_kernel.cli.main:main" + [project.urls] Homepage = "https://github.com/kvcache-ai" [tool.setuptools] -packages = ["kt_kernel", "kt_kernel.utils"] +packages = [ + "kt_kernel", + "kt_kernel.utils", + "kt_kernel.cli", + "kt_kernel.cli.commands", + "kt_kernel.cli.config", + "kt_kernel.cli.utils", + "kt_kernel.cli.completions", +] include-package-data = true [tool.setuptools.package-dir] kt_kernel = "python" "kt_kernel.utils" = "python/utils" +"kt_kernel.cli" = "python/cli" +"kt_kernel.cli.commands" = "python/cli/commands" +"kt_kernel.cli.config" = "python/cli/config" +"kt_kernel.cli.utils" = "python/cli/utils" +"kt_kernel.cli.completions" = "python/cli/completions" [tool.setuptools.package-data] -# (empty) placeholder if you later add resources +"kt_kernel.cli.completions" = ["*.bash", "*.fish", "_kt"] [tool.setuptools.exclude-package-data] # (empty) diff --git a/kt-kernel/python/__init__.py b/kt-kernel/python/__init__.py index 97d99d6..8b13399 100644 --- a/kt-kernel/python/__init__.py +++ b/kt-kernel/python/__init__.py @@ -37,11 +37,13 @@ from __future__ import annotations # Detect CPU and load optimal extension variant from ._cpu_detect import initialize as _initialize_cpu + _kt_kernel_ext, __cpu_variant__ = _initialize_cpu() # Make the extension module available to other modules in this package import sys -sys.modules['kt_kernel_ext'] = _kt_kernel_ext + +sys.modules["kt_kernel_ext"] = _kt_kernel_ext # Also expose kt_kernel_ext as an attribute for backward compatibility kt_kernel_ext = _kt_kernel_ext @@ -53,25 +55,28 @@ from .experts import KTMoEWrapper try: # Try to get version from installed package metadata (works in 
installed environment) from importlib.metadata import version, PackageNotFoundError + try: - __version__ = version('kt-kernel') + __version__ = version("kt-kernel") except PackageNotFoundError: # Package not installed, try to read from source tree version.py import os - _root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'version.py') + + _root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "version.py") if os.path.exists(_root_version_file): _version_ns = {} - with open(_root_version_file, 'r', encoding='utf-8') as f: + with open(_root_version_file, "r", encoding="utf-8") as f: exec(f.read(), _version_ns) - __version__ = _version_ns.get('__version__', '0.4.3') + __version__ = _version_ns.get("__version__", "0.4.3") else: __version__ = "0.4.3" except ImportError: # Python < 3.8, fallback to pkg_resources or hardcoded version try: from pkg_resources import get_distribution, DistributionNotFound + try: - __version__ = get_distribution('kt-kernel').version + __version__ = get_distribution("kt-kernel").version except DistributionNotFound: __version__ = "0.4.3" except ImportError: diff --git a/kt-kernel/python/_cpu_detect.py b/kt-kernel/python/_cpu_detect.py index f0fdab3..c219643 100644 --- a/kt-kernel/python/_cpu_detect.py +++ b/kt-kernel/python/_cpu_detect.py @@ -17,6 +17,7 @@ Example: >>> os.environ['KT_KERNEL_CPU_VARIANT'] = 'avx2' >>> import kt_kernel # Will use AVX2 variant """ + import os import sys from pathlib import Path @@ -35,82 +36,82 @@ def detect_cpu_features(): str: 'amx', 'avx512', or 'avx2' """ # Check environment override - variant = os.environ.get('KT_KERNEL_CPU_VARIANT', '').lower() - if variant in ['amx', 'avx512', 'avx2']: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower() + if variant in ["amx", "avx512", "avx2"]: + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Using environment 
override: {variant}") return variant # Try to read /proc/cpuinfo on Linux try: - with open('/proc/cpuinfo', 'r') as f: + with open("/proc/cpuinfo", "r") as f: cpuinfo = f.read().lower() # Check for AMX support (Intel Sapphire Rapids+) # AMX requires amx_tile, amx_int8, and amx_bf16 - amx_flags = ['amx_tile', 'amx_int8', 'amx_bf16'] + amx_flags = ["amx_tile", "amx_int8", "amx_bf16"] has_amx = all(flag in cpuinfo for flag in amx_flags) if has_amx: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Detected AMX support via /proc/cpuinfo") - return 'amx' + return "amx" # Check for AVX512 support # AVX512F is the foundation for all AVX512 variants - if 'avx512f' in cpuinfo: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if "avx512f" in cpuinfo: + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo") - return 'avx512' + return "avx512" # Check for AVX2 support - if 'avx2' in cpuinfo: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if "avx2" in cpuinfo: + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo") - return 'avx2' + return "avx2" # Fallback to AVX2 (should be rare on modern CPUs) - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback") - return 'avx2' + return "avx2" except FileNotFoundError: # /proc/cpuinfo doesn't exist (not Linux or in container) # Try cpufeature package as fallback - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] /proc/cpuinfo not found, trying cpufeature package") try: import cpufeature # Check for AMX - if cpufeature.CPUFeature.get('AMX_TILE', False): - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if cpufeature.CPUFeature.get("AMX_TILE", False): + if os.environ.get("KT_KERNEL_DEBUG") == "1": 
print("[kt-kernel] Detected AMX support via cpufeature") - return 'amx' + return "amx" # Check for AVX512 - if cpufeature.CPUFeature.get('AVX512F', False): - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if cpufeature.CPUFeature.get("AVX512F", False): + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Detected AVX512 support via cpufeature") - return 'avx512' + return "avx512" # Fallback to AVX2 - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Using AVX2 fallback via cpufeature") - return 'avx2' + return "avx2" except ImportError: # cpufeature not available - ultimate fallback - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] cpufeature not available, using AVX2 fallback") - return 'avx2' + return "avx2" except Exception as e: # Any other error - safe fallback - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Error during CPU detection: {e}, using AVX2 fallback") - return 'avx2' + return "avx2" def load_extension(variant): @@ -148,51 +149,53 @@ def load_extension(variant): kt_kernel_dir = os.path.dirname(os.path.abspath(__file__)) # Try multi-variant naming first - pattern = os.path.join(kt_kernel_dir, f'_kt_kernel_ext_{variant}.*.so') + pattern = os.path.join(kt_kernel_dir, f"_kt_kernel_ext_{variant}.*.so") so_files = glob.glob(pattern) if not so_files: # Try single-variant naming (fallback for builds without CPUINFER_BUILD_ALL_VARIANTS) - pattern = os.path.join(kt_kernel_dir, 'kt_kernel_ext.*.so') + pattern = os.path.join(kt_kernel_dir, "kt_kernel_ext.*.so") so_files = glob.glob(pattern) if so_files: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Multi-variant {variant} not found, using single-variant build") else: - raise ImportError(f"No .so file found for variant {variant} (tried patterns: 
{kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)") + raise ImportError( + f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)" + ) so_file = so_files[0] - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Loading {variant} from: {so_file}") # Load the module manually # The module exports PyInit_kt_kernel_ext, so we use that as the module name - spec = importlib.util.spec_from_file_location('kt_kernel_ext', so_file) + spec = importlib.util.spec_from_file_location("kt_kernel_ext", so_file) if spec is None or spec.loader is None: raise ImportError(f"Failed to create spec for {so_file}") ext = importlib.util.module_from_spec(spec) spec.loader.exec_module(ext) - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Successfully loaded {variant.upper()} variant") return ext except (ImportError, ModuleNotFoundError, FileNotFoundError) as e: - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Failed to load {variant} variant: {e}") # Automatic fallback to next best variant - if variant == 'amx': - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if variant == "amx": + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Falling back from AMX to AVX512") - return load_extension('avx512') - elif variant == 'avx512': - if os.environ.get('KT_KERNEL_DEBUG') == '1': + return load_extension("avx512") + elif variant == "avx512": + if os.environ.get("KT_KERNEL_DEBUG") == "1": print("[kt-kernel] Falling back from AVX512 to AVX2") - return load_extension('avx2') + return load_extension("avx2") else: # AVX2 is the last fallback - if this fails, we can't continue raise ImportError( @@ -221,13 +224,13 @@ def initialize(): # Detect CPU features variant = 
detect_cpu_features() - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Selected CPU variant: {variant}") # Load the appropriate extension ext = load_extension(variant) - if os.environ.get('KT_KERNEL_DEBUG') == '1': + if os.environ.get("KT_KERNEL_DEBUG") == "1": print(f"[kt-kernel] Extension module loaded: {ext.__name__}") return ext, variant diff --git a/kt-kernel/python/cli/__init__.py b/kt-kernel/python/cli/__init__.py new file mode 100644 index 0000000..c3af5ed --- /dev/null +++ b/kt-kernel/python/cli/__init__.py @@ -0,0 +1,8 @@ +""" +KTransformers CLI - A unified command-line interface for KTransformers. + +This CLI provides a user-friendly interface to all KTransformers functionality, +including model inference, fine-tuning, benchmarking, and more. +""" + +__version__ = "0.1.0" diff --git a/kt-kernel/python/cli/commands/__init__.py b/kt-kernel/python/cli/commands/__init__.py new file mode 100644 index 0000000..0a4deaa --- /dev/null +++ b/kt-kernel/python/cli/commands/__init__.py @@ -0,0 +1,3 @@ +""" +Command modules for kt-cli. +""" diff --git a/kt-kernel/python/cli/commands/bench.py b/kt-kernel/python/cli/commands/bench.py new file mode 100644 index 0000000..4283f45 --- /dev/null +++ b/kt-kernel/python/cli/commands/bench.py @@ -0,0 +1,274 @@ +""" +Bench commands for kt-cli. + +Runs benchmarks for performance testing. 
+""" + +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Optional + +import typer + +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import ( + console, + print_error, + print_info, + print_step, + print_success, +) + + +class BenchType(str, Enum): + """Benchmark type.""" + + INFERENCE = "inference" + MLA = "mla" + MOE = "moe" + LINEAR = "linear" + ATTENTION = "attention" + ALL = "all" + + +def bench( + type: BenchType = typer.Option( + BenchType.ALL, + "--type", + "-t", + help="Benchmark type", + ), + model: Optional[str] = typer.Option( + None, + "--model", + "-m", + help="Model to benchmark", + ), + output: Optional[Path] = typer.Option( + None, + "--output", + "-o", + help="Output file for results (JSON)", + ), + iterations: int = typer.Option( + 10, + "--iterations", + "-n", + help="Number of iterations", + ), +) -> None: + """Run full benchmark suite.""" + console.print() + print_step(t("bench_starting")) + print_info(t("bench_type", type=type.value)) + console.print() + + if type == BenchType.ALL: + _run_all_benchmarks(model, output, iterations) + elif type == BenchType.INFERENCE: + _run_inference_benchmark(model, output, iterations) + elif type == BenchType.MLA: + _run_component_benchmark("mla", output, iterations) + elif type == BenchType.MOE: + _run_component_benchmark("moe", output, iterations) + elif type == BenchType.LINEAR: + _run_component_benchmark("linear", output, iterations) + elif type == BenchType.ATTENTION: + _run_component_benchmark("attention", output, iterations) + + console.print() + print_success(t("bench_complete")) + if output: + console.print(f" Results saved to: {output}") + console.print() + + +def microbench( + component: str = typer.Argument( + "moe", + help="Component to benchmark (moe, mla, linear, attention)", + ), + batch_size: int = typer.Option( + 1, + "--batch-size", + "-b", + help="Batch size", + ), + seq_len: int = typer.Option( + 1, + "--seq-len", + 
"-s", + help="Sequence length", + ), + iterations: int = typer.Option( + 100, + "--iterations", + "-n", + help="Number of iterations", + ), + warmup: int = typer.Option( + 10, + "--warmup", + "-w", + help="Warmup iterations", + ), + output: Optional[Path] = typer.Option( + None, + "--output", + "-o", + help="Output file for results (JSON)", + ), +) -> None: + """Run micro-benchmark for specific components.""" + console.print() + console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]") + console.print() + raise typer.Exit(0) + + # Try to find the benchmark script + kt_kernel_path = _find_kt_kernel_path() + + if kt_kernel_path is None: + print_error("kt-kernel not found. Install with: kt install inference") + raise typer.Exit(1) + + bench_dir = kt_kernel_path / "bench" + + # Map component to script + component_scripts = { + "moe": "bench_moe.py", + "mla": "bench_mla.py", + "linear": "bench_linear.py", + "attention": "bench_attention.py", + "mlp": "bench_mlp.py", + } + + script_name = component_scripts.get(component.lower()) + if script_name is None: + print_error(f"Unknown component: {component}") + console.print(f"Available: {', '.join(component_scripts.keys())}") + raise typer.Exit(1) + + script_path = bench_dir / script_name + if not script_path.exists(): + print_error(f"Benchmark script not found: {script_path}") + raise typer.Exit(1) + + # Run benchmark + cmd = [ + sys.executable, + str(script_path), + "--batch-size", + str(batch_size), + "--seq-len", + str(seq_len), + "--iterations", + str(iterations), + "--warmup", + str(warmup), + ] + + if output: + cmd.extend(["--output", str(output)]) + + console.print(f"[dim]$ {' '.join(cmd)}[/dim]") + console.print() + + try: + process = subprocess.run(cmd) + + if process.returncode == 0: + console.print() + print_success(t("bench_complete")) + if output: + console.print(f" Results saved to: {output}") + else: + print_error(f"Benchmark failed with exit code {process.returncode}") + raise 
typer.Exit(process.returncode) + + except FileNotFoundError as e: + print_error(f"Failed to run benchmark: {e}") + raise typer.Exit(1) + + +def _find_kt_kernel_path() -> Optional[Path]: + """Find the kt-kernel installation path.""" + try: + import kt_kernel + + return Path(kt_kernel.__file__).parent.parent + except ImportError: + pass + + # Check common locations + possible_paths = [ + Path.home() / "Projects" / "ktransformers" / "kt-kernel", + Path("/opt/ktransformers/kt-kernel"), + Path.cwd() / "kt-kernel", + ] + + for path in possible_paths: + if path.exists() and (path / "bench").exists(): + return path + + return None + + +def _run_all_benchmarks(model: Optional[str], output: Optional[Path], iterations: int) -> None: + """Run all benchmarks.""" + components = ["moe", "mla", "linear", "attention"] + + for component in components: + console.print(f"\n[bold]Running {component} benchmark...[/bold]") + _run_component_benchmark(component, None, iterations) + + +def _run_inference_benchmark(model: Optional[str], output: Optional[Path], iterations: int) -> None: + """Run inference benchmark.""" + if model is None: + print_error("Model required for inference benchmark. 
Use --model flag.") + raise typer.Exit(1) + + print_info(f"Running inference benchmark on {model}...") + console.print() + console.print("[dim]This will start the server and run test requests.[/dim]") + console.print() + + # TODO: Implement actual inference benchmarking + print_error("Inference benchmarking not yet implemented.") + + +def _run_component_benchmark(component: str, output: Optional[Path], iterations: int) -> None: + """Run a component benchmark.""" + kt_kernel_path = _find_kt_kernel_path() + + if kt_kernel_path is None: + print_error("kt-kernel not found.") + return + + bench_dir = kt_kernel_path / "bench" + script_map = { + "moe": "bench_moe.py", + "mla": "bench_mla.py", + "linear": "bench_linear.py", + "attention": "bench_attention.py", + } + + script_name = script_map.get(component) + if script_name is None: + print_error(f"Unknown component: {component}") + return + + script_path = bench_dir / script_name + if not script_path.exists(): + print_error(f"Script not found: {script_path}") + return + + cmd = [sys.executable, str(script_path), "--iterations", str(iterations)] + + try: + subprocess.run(cmd) + except Exception as e: + print_error(f"Benchmark failed: {e}") diff --git a/kt-kernel/python/cli/commands/chat.py b/kt-kernel/python/cli/commands/chat.py new file mode 100644 index 0000000..978f0a6 --- /dev/null +++ b/kt-kernel/python/cli/commands/chat.py @@ -0,0 +1,437 @@ +""" +Chat command for kt-cli. + +Provides interactive chat interface with running model server. 
+""" + +import json +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.prompt import Prompt, Confirm + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import ( + console, + print_error, + print_info, + print_success, + print_warning, +) + +# Try to import OpenAI SDK +try: + from openai import OpenAI + + HAS_OPENAI = True +except ImportError: + HAS_OPENAI = False + + +def chat( + host: Optional[str] = typer.Option( + None, + "--host", + "-H", + help="Server host address", + ), + port: Optional[int] = typer.Option( + None, + "--port", + "-p", + help="Server port", + ), + model: Optional[str] = typer.Option( + None, + "--model", + "-m", + help="Model name (if server hosts multiple models)", + ), + temperature: float = typer.Option( + 0.7, + "--temperature", + "-t", + help="Sampling temperature (0.0 to 2.0)", + ), + max_tokens: int = typer.Option( + 2048, + "--max-tokens", + help="Maximum tokens to generate", + ), + system_prompt: Optional[str] = typer.Option( + None, + "--system", + "-s", + help="System prompt", + ), + save_history: bool = typer.Option( + True, + "--save-history/--no-save-history", + help="Save conversation history", + ), + history_file: Optional[Path] = typer.Option( + None, + "--history-file", + help="Path to save conversation history", + ), + stream: bool = typer.Option( + True, + "--stream/--no-stream", + help="Enable streaming output", + ), +) -> None: + """Start interactive chat with a running model server. 
+ + Examples: + kt chat # Connect to default server + kt chat --host 127.0.0.1 -p 8080 # Connect to specific server + kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters + """ + if not HAS_OPENAI: + print_error("OpenAI Python SDK is required for chat functionality.") + console.print() + console.print("Install it with:") + console.print(" pip install openai") + raise typer.Exit(1) + + settings = get_settings() + + # Resolve server connection + final_host = host or settings.get("server.host", "127.0.0.1") + final_port = port or settings.get("server.port", 30000) + + # Construct base URL for OpenAI-compatible API + base_url = f"http://{final_host}:{final_port}/v1" + + console.print() + console.print( + Panel.fit( + f"[bold cyan]KTransformers Chat[/bold cyan]\n\n" + f"Server: [yellow]{final_host}:{final_port}[/yellow]\n" + f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n" + f"[dim]Type '/help' for commands, '/quit' to exit[/dim]", + border_style="cyan", + ) + ) + console.print() + + # Check for proxy environment variables + proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy", "ALL_PROXY", "all_proxy"] + detected_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)} + + if detected_proxies: + proxy_info = ", ".join(f"{k}={v}" for k, v in detected_proxies.items()) + console.print() + print_warning(t("chat_proxy_detected")) + console.print(f" [dim]{proxy_info}[/dim]") + console.print() + + use_proxy = Confirm.ask(t("chat_proxy_confirm"), default=False) + + if not use_proxy: + # Temporarily disable proxy for this connection + for var in proxy_vars: + if var in os.environ: + del os.environ[var] + print_info(t("chat_proxy_disabled")) + console.print() + + # Initialize OpenAI client + try: + client = OpenAI( + base_url=base_url, + api_key="EMPTY", # SGLang doesn't require API key + ) + + # Test connection + print_info("Connecting to server...") + models = client.models.list() + 
available_models = [m.id for m in models.data] + + if not available_models: + print_error("No models available on server") + raise typer.Exit(1) + + # Select model + if model: + if model not in available_models: + print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}") + selected_model = available_models[0] + else: + selected_model = model + else: + selected_model = available_models[0] + + print_success(f"Connected to model: {selected_model}") + console.print() + + except Exception as e: + print_error(f"Failed to connect to server: {e}") + console.print() + console.print("Make sure the model server is running:") + console.print(" kt run ") + raise typer.Exit(1) + + # Initialize conversation history + messages = [] + + # Add system prompt if provided + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Setup history file + if save_history: + if history_file is None: + history_dir = settings.config_dir / "chat_history" + history_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + history_file = history_dir / f"chat_{timestamp}.json" + else: + history_file = Path(history_file) + history_file.parent.mkdir(parents=True, exist_ok=True) + + # Main chat loop + try: + while True: + # Get user input + try: + user_input = Prompt.ask("[bold green]You[/bold green]") + except (EOFError, KeyboardInterrupt): + console.print() + print_info("Goodbye!") + break + + if not user_input.strip(): + continue + + # Handle special commands + if user_input.startswith("/"): + if _handle_command(user_input, messages, temperature, max_tokens): + continue + else: + break # Exit command + + # Add user message to history + messages.append({"role": "user", "content": user_input}) + + # Generate response + console.print() + console.print("[bold cyan]Assistant[/bold cyan]") + + try: + if stream: + # Streaming response + response_content = _stream_response(client, selected_model, 
messages, temperature, max_tokens) + else: + # Non-streaming response + response_content = _generate_response(client, selected_model, messages, temperature, max_tokens) + + # Add assistant response to history + messages.append({"role": "assistant", "content": response_content}) + + console.print() + + except Exception as e: + print_error(f"Error generating response: {e}") + # Remove the user message that caused the error + messages.pop() + continue + + # Save history if enabled + if save_history: + _save_history(history_file, messages, selected_model) + + except KeyboardInterrupt: + console.print() + console.print() + print_info("Chat interrupted. Goodbye!") + + # Final history save + if save_history and messages: + _save_history(history_file, messages, selected_model) + console.print(f"[dim]History saved to: {history_file}[/dim]") + console.print() + + +def _stream_response( + client: "OpenAI", + model: str, + messages: list, + temperature: float, + max_tokens: int, +) -> str: + """Generate streaming response and display in real-time.""" + response_content = "" + + try: + stream = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + ) + + for chunk in stream: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + response_content += content + console.print(content, end="") + + console.print() # Newline after streaming + + except Exception as e: + raise Exception(f"Streaming error: {e}") + + return response_content + + +def _generate_response( + client: "OpenAI", + model: str, + messages: list, + temperature: float, + max_tokens: int, +) -> str: + """Generate non-streaming response.""" + try: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + ) + + content = response.choices[0].message.content + + # Display as markdown + md = Markdown(content) + 
console.print(md) + + return content + + except Exception as e: + raise Exception(f"Generation error: {e}") + + +def _handle_command(command: str, messages: list, temperature: float, max_tokens: int) -> bool: + """Handle special commands. Returns True to continue chat, False to exit.""" + cmd = command.lower().strip() + + if cmd in ["/quit", "/exit", "/q"]: + console.print() + print_info("Goodbye!") + return False + + elif cmd in ["/help", "/h"]: + console.print() + console.print( + Panel( + "[bold]Available Commands:[/bold]\n\n" + "/help, /h - Show this help message\n" + "/quit, /exit, /q - Exit chat\n" + "/clear, /c - Clear conversation history\n" + "/history, /hist - Show conversation history\n" + "/info, /i - Show current settings\n" + "/retry, /r - Regenerate last response", + title="Help", + border_style="cyan", + ) + ) + console.print() + return True + + elif cmd in ["/clear", "/c"]: + messages.clear() + console.print() + print_success("Conversation history cleared") + console.print() + return True + + elif cmd in ["/history", "/hist"]: + console.print() + if not messages: + print_info("No conversation history") + else: + console.print( + Panel( + _format_history(messages), + title=f"History ({len(messages)} messages)", + border_style="cyan", + ) + ) + console.print() + return True + + elif cmd in ["/info", "/i"]: + console.print() + console.print( + Panel( + f"[bold]Current Settings:[/bold]\n\n" + f"Temperature: [cyan]{temperature}[/cyan]\n" + f"Max tokens: [cyan]{max_tokens}[/cyan]\n" + f"Messages: [cyan]{len(messages)}[/cyan]", + title="Info", + border_style="cyan", + ) + ) + console.print() + return True + + elif cmd in ["/retry", "/r"]: + if len(messages) >= 2 and messages[-1]["role"] == "assistant": + # Remove last assistant response + messages.pop() + print_info("Retrying last response...") + console.print() + else: + print_warning("No previous response to retry") + console.print() + return True + + else: + print_warning(f"Unknown command: {command}") 
+ console.print("[dim]Type /help for available commands[/dim]") + console.print() + return True + + +def _format_history(messages: list) -> str: + """Format conversation history for display.""" + lines = [] + for i, msg in enumerate(messages, 1): + role = msg["role"].capitalize() + content = msg["content"] + + # Truncate long messages + if len(content) > 200: + content = content[:200] + "..." + + lines.append(f"[bold]{i}. {role}:[/bold] {content}") + + return "\n\n".join(lines) + + +def _save_history(file_path: Path, messages: list, model: str) -> None: + """Save conversation history to file.""" + try: + history_data = { + "model": model, + "timestamp": datetime.now().isoformat(), + "messages": messages, + } + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(history_data, f, indent=2, ensure_ascii=False) + + except Exception as e: + print_warning(f"Failed to save history: {e}") diff --git a/kt-kernel/python/cli/commands/config.py b/kt-kernel/python/cli/commands/config.py new file mode 100644 index 0000000..84f2475 --- /dev/null +++ b/kt-kernel/python/cli/commands/config.py @@ -0,0 +1,167 @@ +""" +Config command for kt-cli. + +Manages kt-cli configuration. 
+""" + +from typing import Optional + +import typer +import yaml +from rich.syntax import Syntax + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import confirm, console, print_error, print_success + +app = typer.Typer(help="Manage kt-cli configuration") + + +@app.command(name="init") +def init() -> None: + """Initialize or re-run the first-time setup wizard.""" + from kt_kernel.cli.main import _show_first_run_setup + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + _show_first_run_setup(settings) + + +@app.command(name="show") +def show( + key: Optional[str] = typer.Argument(None, help="Configuration key to show (e.g., server.port)"), +) -> None: + """Show current configuration.""" + settings = get_settings() + + if key: + value = settings.get(key) + if value is not None: + if isinstance(value, (dict, list)): + console.print(yaml.dump({key: value}, default_flow_style=False, allow_unicode=True)) + else: + console.print(t("config_get_value", key=key, value=value)) + else: + print_error(t("config_get_not_found", key=key)) + raise typer.Exit(1) + else: + console.print(f"\n[bold]{t('config_show_title')}[/bold]\n") + console.print(f"[dim]{t('config_file_location', path=str(settings.config_path))}[/dim]\n") + + config_yaml = yaml.dump(settings.get_all(), default_flow_style=False, allow_unicode=True) + syntax = Syntax(config_yaml, "yaml", theme="monokai", line_numbers=False) + console.print(syntax) + + +@app.command(name="set") +def set_config( + key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"), + value: str = typer.Argument(..., help="Value to set"), +) -> None: + """Set a configuration value.""" + settings = get_settings() + + # Try to parse value as JSON/YAML for complex types + parsed_value = _parse_value(value) + + settings.set(key, parsed_value) + print_success(t("config_set_success", key=key, value=parsed_value)) + + 
+@app.command(name="get") +def get_config( + key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"), +) -> None: + """Get a configuration value.""" + settings = get_settings() + value = settings.get(key) + + if value is not None: + if isinstance(value, (dict, list)): + console.print(yaml.dump(value, default_flow_style=False, allow_unicode=True)) + else: + console.print(str(value)) + else: + print_error(t("config_get_not_found", key=key)) + raise typer.Exit(1) + + +@app.command(name="reset") +def reset( + yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"), +) -> None: + """Reset configuration to defaults.""" + if not yes: + if not confirm(t("config_reset_confirm"), default=False): + raise typer.Abort() + + settings = get_settings() + settings.reset() + print_success(t("config_reset_success")) + + +@app.command(name="path") +def path() -> None: + """Show configuration file path.""" + settings = get_settings() + console.print(str(settings.config_path)) + + +@app.command(name="model-path-list", deprecated=True, hidden=True) +def model_path_list() -> None: + """[Deprecated] Use 'kt model path-list' instead.""" + console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-list' instead.[/yellow]\n") + import subprocess + subprocess.run(["kt", "model", "path-list"]) + + +@app.command(name="model-path-add", deprecated=True, hidden=True) +def model_path_add( + path: str = typer.Argument(..., help="Path to add"), +) -> None: + """[Deprecated] Use 'kt model path-add' instead.""" + console.print("[yellow]⚠ This command is deprecated. 
Use 'kt model path-add' instead.[/yellow]\n") + import subprocess + subprocess.run(["kt", "model", "path-add", path]) + + +@app.command(name="model-path-remove", deprecated=True, hidden=True) +def model_path_remove( + path: str = typer.Argument(..., help="Path to remove"), +) -> None: + """[Deprecated] Use 'kt model path-remove' instead.""" + console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-remove' instead.[/yellow]\n") + import subprocess + subprocess.run(["kt", "model", "path-remove", path]) + + +def _parse_value(value: str): + """Parse a string value into appropriate Python type.""" + # Try boolean + if value.lower() in ("true", "yes", "on", "1"): + return True + if value.lower() in ("false", "no", "off", "0"): + return False + + # Try integer + try: + return int(value) + except ValueError: + pass + + # Try float + try: + return float(value) + except ValueError: + pass + + # Try YAML/JSON parsing for lists/dicts + try: + parsed = yaml.safe_load(value) + if isinstance(parsed, (dict, list)): + return parsed + except yaml.YAMLError: + pass + + # Return as string + return value diff --git a/kt-kernel/python/cli/commands/doctor.py b/kt-kernel/python/cli/commands/doctor.py new file mode 100644 index 0000000..681ece2 --- /dev/null +++ b/kt-kernel/python/cli/commands/doctor.py @@ -0,0 +1,394 @@ +""" +Doctor command for kt-cli. + +Diagnoses environment issues and provides recommendations. 
+""" + +import platform +import shutil +from pathlib import Path +from typing import Optional + +import typer +from rich.table import Table + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import console, print_error, print_info, print_success, print_warning +from kt_kernel.cli.utils.environment import ( + check_docker, + detect_available_ram_gb, + detect_cpu_info, + detect_cuda_version, + detect_disk_space_gb, + detect_env_managers, + detect_gpus, + detect_memory_info, + detect_ram_gb, + get_installed_package_version, +) + + +def doctor( + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed diagnostics"), +) -> None: + """Diagnose environment issues.""" + console.print(f"\n[bold]{t('doctor_title')}[/bold]\n") + + issues_found = False + checks = [] + + # 1. Python version + python_version = platform.python_version() + python_ok = _check_python_version(python_version) + checks.append( + { + "name": t("doctor_check_python"), + "status": "ok" if python_ok else "error", + "value": python_version, + "hint": "Python 3.10+ required" if not python_ok else None, + } + ) + if not python_ok: + issues_found = True + + # 2. CUDA availability + cuda_version = detect_cuda_version() + checks.append( + { + "name": t("doctor_check_cuda"), + "status": "ok" if cuda_version else "warning", + "value": cuda_version or t("version_cuda_not_found"), + "hint": "CUDA is optional but recommended for GPU acceleration" if not cuda_version else None, + } + ) + + # 3. 
GPU detection + gpus = detect_gpus() + if gpus: + gpu_names = ", ".join(g.name for g in gpus) + total_vram = sum(g.vram_gb for g in gpus) + checks.append( + { + "name": t("doctor_check_gpu"), + "status": "ok", + "value": t("doctor_gpu_found", count=len(gpus), names=gpu_names), + "hint": f"Total VRAM: {total_vram}GB", + } + ) + else: + checks.append( + { + "name": t("doctor_check_gpu"), + "status": "warning", + "value": t("doctor_gpu_not_found"), + "hint": "GPU recommended for best performance", + } + ) + + # 4. CPU information + cpu_info = detect_cpu_info() + checks.append( + { + "name": t("doctor_check_cpu"), + "status": "ok", + "value": t("doctor_cpu_info", name=cpu_info.name, cores=cpu_info.cores, threads=cpu_info.threads), + "hint": None, + } + ) + + # 5. CPU instruction sets (critical for kt-kernel) + isa_list = cpu_info.instruction_sets + # Check for recommended instruction sets + recommended_isa = {"AVX2", "AVX512F", "AMX-INT8"} + has_recommended = bool(set(isa_list) & recommended_isa) + has_avx2 = "AVX2" in isa_list + has_avx512 = any(isa.startswith("AVX512") for isa in isa_list) + has_amx = any(isa.startswith("AMX") for isa in isa_list) + + # Determine status and build display string + if has_amx: + isa_status = "ok" + isa_hint = "AMX available - best performance for INT4/INT8" + elif has_avx512: + isa_status = "ok" + isa_hint = "AVX512 available - good performance" + elif has_avx2: + isa_status = "warning" + isa_hint = "AVX2 only - consider upgrading CPU for better performance" + else: + isa_status = "error" + isa_hint = "AVX2 required for kt-kernel" + + # Show top instruction sets (prioritize important ones) + display_isa = isa_list[:8] if len(isa_list) > 8 else isa_list + isa_display = ", ".join(display_isa) + if len(isa_list) > 8: + isa_display += f" (+{len(isa_list) - 8} more)" + + checks.append( + { + "name": t("doctor_check_cpu_isa"), + "status": isa_status, + "value": isa_display if isa_display else "None detected", + "hint": isa_hint, + } + ) + + 
# 6. NUMA topology + numa_detail = [] + for node, cpus in sorted(cpu_info.numa_info.items()): + if len(cpus) > 6: + cpu_str = f"{cpus[0]}-{cpus[-1]}" + else: + cpu_str = ",".join(str(c) for c in cpus) + numa_detail.append(f"{node}: {cpu_str}") + + numa_value = t("doctor_numa_info", nodes=cpu_info.numa_nodes) + if verbose and numa_detail: + numa_value += " (" + "; ".join(numa_detail) + ")" + + checks.append( + { + "name": t("doctor_check_numa"), + "status": "ok", + "value": numa_value, + "hint": f"{cpu_info.threads // cpu_info.numa_nodes} threads per node" if cpu_info.numa_nodes > 1 else None, + } + ) + + # 7. System memory (with frequency if available) + mem_info = detect_memory_info() + if mem_info.frequency_mhz and mem_info.type: + mem_value = t( + "doctor_memory_freq", + available=f"{mem_info.available_gb}GB", + total=f"{mem_info.total_gb}GB", + freq=mem_info.frequency_mhz, + type=mem_info.type, + ) + else: + mem_value = t("doctor_memory_info", available=f"{mem_info.available_gb}GB", total=f"{mem_info.total_gb}GB") + + ram_ok = mem_info.total_gb >= 32 + checks.append( + { + "name": t("doctor_check_memory"), + "status": "ok" if ram_ok else "warning", + "value": mem_value, + "hint": "32GB+ RAM recommended for large models" if not ram_ok else None, + } + ) + + # 8. Disk space - check all model paths + settings = get_settings() + model_paths = settings.get_model_paths() + + # Check all configured model paths + for i, disk_path in enumerate(model_paths): + available_disk, total_disk = detect_disk_space_gb(str(disk_path)) + disk_ok = available_disk >= 100 + + # For multiple paths, add index to name + path_label = f"Model Path {i+1}" if len(model_paths) > 1 else t("doctor_check_disk") + + checks.append( + { + "name": path_label, + "status": "ok" if disk_ok else "warning", + "value": t("doctor_disk_info", available=f"{available_disk}GB", path=str(disk_path)), + "hint": "100GB+ free space recommended for model storage" if not disk_ok else None, + } + ) + + # 6. 
Required packages + packages = [ + ("kt-kernel", ">=0.4.0", False), # name, version_req, required + ("ktransformers", ">=0.4.0", False), + ("sglang", ">=0.4.0", False), + ("torch", ">=2.4.0", True), + ("transformers", ">=4.45.0", True), + ] + + package_issues = [] + for pkg_name, version_req, required in packages: + version = get_installed_package_version(pkg_name) + if version: + package_issues.append((pkg_name, version, "ok")) + elif required: + package_issues.append((pkg_name, t("version_not_installed"), "error")) + issues_found = True + else: + package_issues.append((pkg_name, t("version_not_installed"), "warning")) + + if verbose: + checks.append( + { + "name": t("doctor_check_packages"), + "status": "ok" if not any(p[2] == "error" for p in package_issues) else "error", + "value": f"{sum(1 for p in package_issues if p[2] == 'ok')}/{len(package_issues)} installed", + "packages": package_issues, + } + ) + + # 7. SGLang installation source check + from kt_kernel.cli.utils.sglang_checker import check_sglang_installation, check_sglang_kt_kernel_support + + sglang_info = check_sglang_installation() + + if sglang_info["installed"]: + if sglang_info["from_source"]: + if sglang_info["git_info"]: + git_remote = sglang_info["git_info"].get("remote", "unknown") + git_branch = sglang_info["git_info"].get("branch", "unknown") + sglang_source_value = f"Source (GitHub: {git_remote}, branch: {git_branch})" + sglang_source_status = "ok" + sglang_source_hint = None + else: + sglang_source_value = "Source (editable)" + sglang_source_status = "ok" + sglang_source_hint = None + else: + sglang_source_value = "PyPI (not recommended)" + sglang_source_status = "warning" + sglang_source_hint = t("sglang_pypi_hint") + else: + sglang_source_value = "Not installed" + sglang_source_status = "warning" + sglang_source_hint = t("sglang_install_hint") + + checks.append( + { + "name": "SGLang Source", + "status": sglang_source_status, + "value": sglang_source_value, + "hint": sglang_source_hint, 
+ } + ) + + # 7b. SGLang kt-kernel support check (only if SGLang is installed) + kt_kernel_support = {"supported": True} # Default to True if not checked + if sglang_info["installed"]: + # Use cache=False to force re-check in doctor, but silent=True since we show in table + kt_kernel_support = check_sglang_kt_kernel_support(use_cache=False, silent=True) + + if kt_kernel_support["supported"]: + kt_kernel_value = t("sglang_kt_kernel_supported") + kt_kernel_status = "ok" + kt_kernel_hint = None + else: + kt_kernel_value = t("sglang_kt_kernel_not_supported") + kt_kernel_status = "error" + kt_kernel_hint = 'Reinstall SGLang from: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"' + issues_found = True + + checks.append( + { + "name": "SGLang kt-kernel", + "status": kt_kernel_status, + "value": kt_kernel_value, + "hint": kt_kernel_hint, + } + ) + + # 8. Environment managers + env_managers = detect_env_managers() + docker = check_docker() + env_list = [f"{m.name} {m.version}" for m in env_managers] + if docker: + env_list.append(f"docker {docker.version}") + + checks.append( + { + "name": "Environment Managers", + "status": "ok" if env_list else "warning", + "value": ", ".join(env_list) if env_list else "None found", + "hint": "conda or docker recommended for installation" if not env_list else None, + } + ) + + # Display results + _display_results(checks, verbose) + + # Show SGLang installation instructions if not installed + if not sglang_info["installed"]: + from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions + + console.print() + print_sglang_install_instructions() + # Show kt-kernel installation instructions if SGLang is installed but doesn't support kt-kernel + elif sglang_info["installed"] and not kt_kernel_support.get("supported", True): + from kt_kernel.cli.utils.sglang_checker import print_sglang_kt_kernel_instructions + + console.print() + print_sglang_kt_kernel_instructions() + + # Summary + 
console.print() + if issues_found: + print_warning(t("doctor_has_issues")) + else: + print_success(t("doctor_all_ok")) + console.print() + + +def _check_python_version(version: str) -> bool: + """Check if Python version meets requirements.""" + parts = version.split(".") + try: + major, minor = int(parts[0]), int(parts[1]) + return major >= 3 and minor >= 10 + except (IndexError, ValueError): + return False + + +def _display_results(checks: list[dict], verbose: bool) -> None: + """Display diagnostic results.""" + table = Table(show_header=True, header_style="bold") + table.add_column("Check", style="bold") + table.add_column("Status", width=8) + table.add_column("Value") + if verbose: + table.add_column("Notes", style="dim") + + for check in checks: + status = check["status"] + if status == "ok": + status_str = f"[green]{t('doctor_status_ok')}[/green]" + elif status == "warning": + status_str = f"[yellow]{t('doctor_status_warning')}[/yellow]" + else: + status_str = f"[red]{t('doctor_status_error')}[/red]" + + if verbose: + table.add_row( + check["name"], + status_str, + check["value"], + check.get("hint", ""), + ) + else: + table.add_row( + check["name"], + status_str, + check["value"], + ) + + # Show package details if verbose + if verbose and "packages" in check: + for pkg_name, pkg_version, pkg_status in check["packages"]: + if pkg_status == "ok": + pkg_status_str = "[green]✓[/green]" + elif pkg_status == "warning": + pkg_status_str = "[yellow]○[/yellow]" + else: + pkg_status_str = "[red]✗[/red]" + + table.add_row( + f" └─ {pkg_name}", + pkg_status_str, + pkg_version, + "", + ) + + console.print(table) diff --git a/kt-kernel/python/cli/commands/model.py b/kt-kernel/python/cli/commands/model.py new file mode 100644 index 0000000..772ef8b --- /dev/null +++ b/kt-kernel/python/cli/commands/model.py @@ -0,0 +1,409 @@ +""" +Model command for kt-cli. + +Manages models: download, list, and storage paths. 
+""" + +import os +from pathlib import Path +from typing import Optional + +import typer + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import ( + confirm, + console, + print_error, + print_info, + print_success, + print_warning, + prompt_choice, +) + +app = typer.Typer( + help="Manage models and storage paths", + invoke_without_command=True, + no_args_is_help=False, +) + + +@app.callback() +def callback(ctx: typer.Context) -> None: + """ + Model management commands. + + Run without arguments to see available models. + """ + # If no subcommand is provided, show the model list + if ctx.invoked_subcommand is None: + show_model_list() + + +def show_model_list() -> None: + """Display available models with their status and paths.""" + from rich.table import Table + from kt_kernel.cli.utils.model_registry import get_registry + from kt_kernel.cli.i18n import get_lang + + registry = get_registry() + settings = get_settings() + + console.print() + console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n") + + # Get local models mapping + local_models = {m.name: p for m, p in registry.find_local_models()} + + # Create table + table = Table(show_header=True, header_style="bold") + table.add_column(t("model_column_model"), style="cyan", no_wrap=True) + table.add_column(t("model_column_status"), justify="center") + + all_models = registry.list_all() + for model in all_models: + if model.name in local_models: + status = f"[green]✓ {t('model_status_local')}[/green]" + else: + status = "[dim]-[/dim]" + + table.add_row(model.name, status) + + console.print(table) + console.print() + + # Usage instructions + console.print(f"[bold]{t('model_usage_title')}:[/bold]") + console.print(f" • {t('model_usage_download')} [cyan]kt model download [/cyan]") + console.print(f" • {t('model_usage_list_local')} [cyan]kt model list --local[/cyan]") + console.print(f" • {t('model_usage_search')} [cyan]kt model 
search [/cyan]") + console.print() + + # Show model storage paths + model_paths = settings.get_model_paths() + console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]") + for path in model_paths: + marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]" + console.print(f" {marker} {path}") + console.print() + + +@app.command(name="download") +def download( + model: Optional[str] = typer.Argument( + None, + help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)", + ), + path: Optional[Path] = typer.Option( + None, + "--path", + "-p", + help="Custom download path", + ), + list_models: bool = typer.Option( + False, + "--list", + "-l", + help="List available models", + ), + resume: bool = typer.Option( + True, + "--resume/--no-resume", + help="Resume incomplete downloads", + ), + yes: bool = typer.Option( + False, + "--yes", + "-y", + help="Skip confirmation prompts", + ), +) -> None: + """Download model weights from HuggingFace.""" + import subprocess + from kt_kernel.cli.i18n import get_lang + from kt_kernel.cli.utils.console import print_model_table, print_step + from kt_kernel.cli.utils.model_registry import get_registry + + settings = get_settings() + registry = get_registry() + + console.print() + + # List mode + if list_models or model is None: + print_step(t("download_list_title")) + console.print() + + models = registry.list_all() + model_dicts = [] + for m in models: + lang = get_lang() + desc = m.description_zh if lang == "zh" and m.description_zh else m.description + model_dicts.append( + { + "name": m.name, + "hf_repo": m.hf_repo, + "type": m.type, + "gpu_vram_gb": m.gpu_vram_gb, + "cpu_ram_gb": m.cpu_ram_gb, + } + ) + + print_model_table(model_dicts) + console.print() + + if model is None: + console.print(f"[dim]{t('model_download_usage_hint')}[/dim]") + console.print() + return + + # Search for model + print_step(t("download_searching", name=model)) + + # Check if it's a direct HuggingFace repo path + if "/" in model: + 
hf_repo = model + model_info = None + model_name = model.split("/")[-1] + else: + matches = registry.search(model) + + if not matches: + print_error(t("run_model_not_found", name=model)) + console.print() + console.print(t("model_download_list_hint")) + console.print(t("model_download_hf_hint")) + raise typer.Exit(1) + + if len(matches) == 1: + model_info = matches[0] + else: + console.print() + print_info(t("download_multiple_found")) + choices = [f"{m.name} ({m.hf_repo})" for m in matches] + selected = prompt_choice(t("download_select"), choices) + idx = choices.index(selected) + model_info = matches[idx] + + hf_repo = model_info.hf_repo + model_name = model_info.name + + print_success(t("download_found", name=hf_repo)) + + # Determine download path + if path is None: + download_path = settings.models_dir / model_name.replace(" ", "-") + else: + download_path = path + + console.print() + print_info(t("download_destination", path=str(download_path))) + + # Check if already exists + if download_path.exists() and (download_path / "config.json").exists(): + print_warning(t("download_already_exists", path=str(download_path))) + if not yes: + if not confirm(t("download_overwrite_prompt"), default=False): + raise typer.Abort() + + # Confirm download + if not yes: + console.print() + if not confirm(t("prompt_continue")): + raise typer.Abort() + + # Download using huggingface-cli + console.print() + print_step(t("download_starting")) + + cmd = [ + "huggingface-cli", + "download", + hf_repo, + "--local-dir", + str(download_path), + ] + + if resume: + cmd.append("--resume-download") + + # Add mirror if configured + mirror = settings.get("download.mirror", "") + if mirror: + cmd.extend(["--endpoint", mirror]) + + try: + process = subprocess.run(cmd, check=True) + + console.print() + print_success(t("download_complete")) + console.print() + console.print(f" {t('model_saved_to', path=download_path)}") + console.print() + console.print(f" {t('model_start_with', 
name=model_name)}") + console.print() + + except subprocess.CalledProcessError as e: + print_error(t("model_download_failed", error=str(e))) + raise typer.Exit(1) + except FileNotFoundError: + print_error(t("model_hf_cli_not_found")) + raise typer.Exit(1) + + +@app.command(name="list") +def list_models( + local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"), +) -> None: + """List available models.""" + from rich.table import Table + from kt_kernel.cli.utils.model_registry import get_registry + + registry = get_registry() + console.print() + + if local_only: + # Show only local models + local_models = registry.find_local_models() + + if not local_models: + print_warning(t("model_no_local_models")) + console.print() + console.print(f" {t('model_download_hint')} [cyan]kt model download [/cyan]") + console.print() + return + + table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold") + table.add_column(t("model_column_model"), style="cyan", no_wrap=True) + if verbose: + table.add_column(t("model_column_local_path"), style="dim") + + for model_info, model_path in local_models: + if verbose: + table.add_row(model_info.name, str(model_path)) + else: + table.add_row(model_info.name) + + console.print(table) + else: + # Show all registered models + all_models = registry.list_all() + local_models_dict = {m.name: p for m, p in registry.find_local_models()} + + table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold") + table.add_column(t("model_column_model"), style="cyan", no_wrap=True) + table.add_column(t("model_column_status"), justify="center") + if verbose: + table.add_column(t("model_column_local_path"), style="dim") + + for model in all_models: + if model.name in local_models_dict: + status = f"[green]✓ {t('model_status_local')}[/green]" + local_path = 
str(local_models_dict[model.name]) + else: + status = "[dim]-[/dim]" + local_path = f"[dim]{t('model_status_not_downloaded')}[/dim]" + + if verbose: + table.add_row(model.name, status, local_path) + else: + table.add_row(model.name, status) + + console.print(table) + + console.print() + + +@app.command(name="path-list") +def path_list() -> None: + """List all configured model storage paths.""" + settings = get_settings() + model_paths = settings.get_model_paths() + + console.print() + console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]\n") + + for i, path in enumerate(model_paths, 1): + marker = "[green]✓[/green]" if path.exists() else "[red]✗[/red]" + console.print(f" {marker} [{i}] {path}") + + console.print() + + +@app.command(name="path-add") +def path_add( + path: str = typer.Argument(..., help="Path to add"), +) -> None: + """Add a new model storage path.""" + # Expand user home directory + path = os.path.expanduser(path) + + # Check if path exists or can be created + path_obj = Path(path) + if not path_obj.exists(): + console.print(f"[yellow]{t('model_path_not_exist', path=path)}[/yellow]") + if confirm(t("model_create_directory", path=path), default=True): + try: + path_obj.mkdir(parents=True, exist_ok=True) + console.print(f"[green]✓[/green] {t('model_created_directory', path=path)}") + except (OSError, PermissionError) as e: + print_error(t("model_create_dir_failed", error=str(e))) + raise typer.Exit(1) + else: + raise typer.Abort() + + # Add to configuration + settings = get_settings() + settings.add_model_path(path) + print_success(t("model_path_added", path=path)) + + +@app.command(name="path-remove") +def path_remove( + path: str = typer.Argument(..., help="Path to remove"), +) -> None: + """Remove a model storage path from configuration.""" + # Expand user home directory + path = os.path.expanduser(path) + + settings = get_settings() + if settings.remove_model_path(path): + print_success(t("model_path_removed", path=path)) + else: + 
print_error(t("model_path_not_found", path=path)) + raise typer.Exit(1) + + +@app.command(name="search") +def search( + query: str = typer.Argument(..., help="Search query (model name or keyword)"), +) -> None: + """Search for models in the registry.""" + from rich.table import Table + from kt_kernel.cli.utils.model_registry import get_registry + + registry = get_registry() + matches = registry.search(query) + + console.print() + + if not matches: + print_warning(t("model_search_no_results", query=query)) + console.print() + return + + table = Table(title=t("model_search_results_title", query=query), show_header=True) + table.add_column(t("model_column_name"), style="cyan") + table.add_column(t("model_column_hf_repo"), style="dim") + table.add_column(t("model_column_aliases"), style="yellow") + + for model in matches: + aliases = ", ".join(model.aliases[:3]) + if len(model.aliases) > 3: + aliases += f" +{len(model.aliases) - 3} more" + table.add_row(model.name, model.hf_repo, aliases) + + console.print(table) + console.print() diff --git a/kt-kernel/python/cli/commands/quant.py b/kt-kernel/python/cli/commands/quant.py new file mode 100644 index 0000000..c6cf2c3 --- /dev/null +++ b/kt-kernel/python/cli/commands/quant.py @@ -0,0 +1,239 @@ +""" +Quant command for kt-cli. + +Quantizes model weights for CPU inference. 
+""" + +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Optional + +import typer + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import ( + confirm, + console, + create_progress, + print_error, + print_info, + print_step, + print_success, + print_warning, +) +from kt_kernel.cli.utils.environment import detect_cpu_info + + +class QuantMethod(str, Enum): + """Quantization method.""" + + INT4 = "int4" + INT8 = "int8" + + +def quant( + model: str = typer.Argument( + ..., + help="Model name or path to quantize", + ), + method: QuantMethod = typer.Option( + QuantMethod.INT4, + "--method", + "-m", + help="Quantization method", + ), + output: Optional[Path] = typer.Option( + None, + "--output", + "-o", + help="Output path for quantized weights", + ), + input_type: str = typer.Option( + "fp8", + "--input-type", + "-i", + help="Input weight type (fp8, fp16, bf16)", + ), + cpu_threads: Optional[int] = typer.Option( + None, + "--cpu-threads", + help="Number of CPU threads for quantization", + ), + numa_nodes: Optional[int] = typer.Option( + None, + "--numa-nodes", + help="Number of NUMA nodes", + ), + no_merge: bool = typer.Option( + False, + "--no-merge", + help="Don't merge safetensor files", + ), + yes: bool = typer.Option( + False, + "--yes", + "-y", + help="Skip confirmation prompts", + ), +) -> None: + """Quantize model weights for CPU inference.""" + settings = get_settings() + console.print() + + # Resolve input path + input_path = _resolve_input_path(model, settings) + if input_path is None: + print_error(t("quant_input_not_found", path=model)) + raise typer.Exit(1) + + print_info(t("quant_input_path", path=str(input_path))) + + # Resolve output path + if output is None: + output = input_path.parent / f"{input_path.name}-{method.value.upper()}" + + print_info(t("quant_output_path", path=str(output))) + print_info(t("quant_method", 
method=method.value.upper())) + + # Detect CPU configuration + cpu = detect_cpu_info() + final_cpu_threads = cpu_threads or cpu.cores + final_numa_nodes = numa_nodes or cpu.numa_nodes + + print_info(f"CPU threads: {final_cpu_threads}") + print_info(f"NUMA nodes: {final_numa_nodes}") + + # Check if output exists + if output.exists(): + print_warning(f"Output path already exists: {output}") + if not yes: + if not confirm("Overwrite?", default=False): + raise typer.Abort() + + # Confirm + if not yes: + console.print() + console.print("[bold]Quantization Settings:[/bold]") + console.print(f" Input: {input_path}") + console.print(f" Output: {output}") + console.print(f" Method: {method.value.upper()}") + console.print(f" Input type: {input_type}") + console.print() + print_warning("Quantization may take 30-60 minutes depending on model size.") + console.print() + + if not confirm(t("prompt_continue")): + raise typer.Abort() + + # Find conversion script + kt_kernel_path = _find_kt_kernel_path() + if kt_kernel_path is None: + print_error("kt-kernel not found. 
Install with: kt install inference") + raise typer.Exit(1) + + script_path = kt_kernel_path / "scripts" / "convert_cpu_weights.py" + if not script_path.exists(): + print_error(f"Conversion script not found: {script_path}") + raise typer.Exit(1) + + # Build command + cmd = [ + sys.executable, str(script_path), + "--input-path", str(input_path), + "--input-type", input_type, + "--output", str(output), + "--quant-method", method.value, + "--cpuinfer-threads", str(final_cpu_threads), + "--threadpool-count", str(final_numa_nodes), + ] + + if no_merge: + cmd.append("--no-merge-safetensor") + + # Run quantization + console.print() + print_step(t("quant_starting")) + console.print() + console.print(f"[dim]$ {' '.join(cmd)}[/dim]") + console.print() + + try: + process = subprocess.run(cmd) + + if process.returncode == 0: + console.print() + print_success(t("quant_complete")) + console.print() + console.print(f" Quantized weights saved to: {output}") + console.print() + console.print(" Use with:") + console.print(f" kt run {model} --weights-path {output}") + console.print() + else: + print_error(f"Quantization failed with exit code {process.returncode}") + raise typer.Exit(process.returncode) + + except FileNotFoundError as e: + print_error(f"Failed to run quantization: {e}") + raise typer.Exit(1) + except KeyboardInterrupt: + console.print() + print_warning("Quantization interrupted.") + raise typer.Exit(130) + + +def _resolve_input_path(model: str, settings) -> Optional[Path]: + """Resolve the input model path.""" + # Check if it's already a path + path = Path(model) + if path.exists() and (path / "config.json").exists(): + return path + + # Search in models directory + from kt_kernel.cli.utils.model_registry import get_registry + + registry = get_registry() + matches = registry.search(model) + + if matches: + model_info = matches[0] + # Try to find in all configured model directories + model_paths = settings.get_model_paths() + + for models_dir in model_paths: + 
possible_paths = [ + models_dir / model_info.name, + models_dir / model_info.name.lower(), + models_dir / model_info.hf_repo.split("/")[-1], + ] + + for p in possible_paths: + if p.exists() and (p / "config.json").exists(): + return p + + return None + + +def _find_kt_kernel_path() -> Optional[Path]: + """Find the kt-kernel installation path.""" + try: + import kt_kernel + return Path(kt_kernel.__file__).parent.parent + except ImportError: + pass + + # Check common locations + possible_paths = [ + Path.home() / "Projects" / "ktransformers" / "kt-kernel", + Path.cwd().parent / "kt-kernel", + Path.cwd() / "kt-kernel", + ] + + for path in possible_paths: + if path.exists() and (path / "scripts").exists(): + return path + + return None diff --git a/kt-kernel/python/cli/commands/run.py b/kt-kernel/python/cli/commands/run.py new file mode 100644 index 0000000..7bcc085 --- /dev/null +++ b/kt-kernel/python/cli/commands/run.py @@ -0,0 +1,831 @@ +""" +Run command for kt-cli. + +Starts the model inference server using SGLang + kt-kernel. +""" + +import os +import subprocess +import sys +from pathlib import Path +from typing import Optional + +import typer + +from kt_kernel.cli.config.settings import get_settings +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import ( + confirm, + console, + print_api_info, + print_error, + print_info, + print_server_info, + print_step, + print_success, + print_warning, + prompt_choice, +) +from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb +from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry + + +def run( + model: Optional[str] = typer.Argument( + None, + help="Model name or path (e.g., deepseek-v3, qwen3-30b). 
If not specified, shows interactive selection.", + ), + host: str = typer.Option( + None, + "--host", + "-H", + help="Server host address", + ), + port: int = typer.Option( + None, + "--port", + "-p", + help="Server port", + ), + # CPU/GPU configuration + gpu_experts: Optional[int] = typer.Option( + None, + "--gpu-experts", + help="Number of GPU experts per layer", + ), + cpu_threads: Optional[int] = typer.Option( + None, + "--cpu-threads", + help="Number of CPU inference threads (kt-cpuinfer, defaults to 80% of CPU cores)", + ), + numa_nodes: Optional[int] = typer.Option( + None, + "--numa-nodes", + help="Number of NUMA nodes", + ), + tensor_parallel_size: Optional[int] = typer.Option( + None, + "--tensor-parallel-size", + "--tp", + help="Tensor parallel size (number of GPUs)", + ), + # Model paths + model_path: Optional[Path] = typer.Option( + None, + "--model-path", + help="Custom model path", + ), + weights_path: Optional[Path] = typer.Option( + None, + "--weights-path", + help="Custom quantized weights path", + ), + # KT-kernel options + kt_method: Optional[str] = typer.Option( + None, + "--kt-method", + help="KT quantization method (AMXINT4, RAWFP8, etc.)", + ), + kt_gpu_prefill_token_threshold: Optional[int] = typer.Option( + None, + "--kt-gpu-prefill-threshold", + help="GPU prefill token threshold for kt-kernel", + ), + # SGLang options + attention_backend: Optional[str] = typer.Option( + None, + "--attention-backend", + help="Attention backend (triton, flashinfer)", + ), + max_total_tokens: Optional[int] = typer.Option( + None, + "--max-total-tokens", + help="Maximum total tokens", + ), + max_running_requests: Optional[int] = typer.Option( + None, + "--max-running-requests", + help="Maximum running requests", + ), + chunked_prefill_size: Optional[int] = typer.Option( + None, + "--chunked-prefill-size", + help="Chunked prefill size", + ), + mem_fraction_static: Optional[float] = typer.Option( + None, + "--mem-fraction-static", + help="Memory fraction for 
static allocation", + ), + watchdog_timeout: Optional[int] = typer.Option( + None, + "--watchdog-timeout", + help="Watchdog timeout in seconds", + ), + served_model_name: Optional[str] = typer.Option( + None, + "--served-model-name", + help="Custom model name for API responses", + ), + # Performance flags + disable_shared_experts_fusion: Optional[bool] = typer.Option( + None, + "--disable-shared-experts-fusion/--enable-shared-experts-fusion", + help="Disable/enable shared experts fusion", + ), + # Other options + quantize: bool = typer.Option( + False, + "--quantize", + "-q", + help="Quantize model if weights not found", + ), + advanced: bool = typer.Option( + False, + "--advanced", + help="Show advanced options", + ), + dry_run: bool = typer.Option( + False, + "--dry-run", + help="Show command without executing", + ), +) -> None: + """Start model inference server.""" + # Check if SGLang is installed before proceeding + from kt_kernel.cli.utils.sglang_checker import ( + check_sglang_installation, + check_sglang_kt_kernel_support, + print_sglang_install_instructions, + print_sglang_kt_kernel_instructions, + ) + + sglang_info = check_sglang_installation() + if not sglang_info["installed"]: + console.print() + print_error(t("sglang_not_found")) + console.print() + print_sglang_install_instructions() + raise typer.Exit(1) + + # Check if SGLang supports kt-kernel (has --kt-gpu-prefill-token-threshold parameter) + kt_kernel_support = check_sglang_kt_kernel_support() + if not kt_kernel_support["supported"]: + console.print() + print_error(t("sglang_kt_kernel_not_supported")) + console.print() + print_sglang_kt_kernel_instructions() + raise typer.Exit(1) + + settings = get_settings() + registry = get_registry() + + console.print() + + # If no model specified, show interactive selection + if model is None: + model = _interactive_model_selection(registry, settings) + if model is None: + raise typer.Exit(0) + + # Step 1: Detect hardware + 
print_step(t("run_detecting_hardware")) + gpus = detect_gpus() + cpu = detect_cpu_info() + ram = detect_ram_gb() + + if gpus: + gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)" + if len(gpus) > 1: + gpu_info += f" + {len(gpus) - 1} more" + print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb)) + else: + print_warning(t("doctor_gpu_not_found")) + gpu_info = "None" + + print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes)) + print_info(t("run_ram_info", total=int(ram))) + + # Step 2: Resolve model + console.print() + print_step(t("run_checking_model")) + + model_info = None + resolved_model_path = model_path + + # Check if model is a path + if Path(model).exists(): + resolved_model_path = Path(model) + print_info(t("run_model_path", path=str(resolved_model_path))) + + # Try to infer model type from path to use default configurations + # Check directory name against known models + dir_name = resolved_model_path.name.lower() + for registered_model in registry.list_all(): + # Check if directory name matches model name or aliases + if dir_name == registered_model.name.lower(): + model_info = registered_model + print_info(f"Detected model type: {registered_model.name}") + break + for alias in registered_model.aliases: + if dir_name == alias.lower() or alias.lower() in dir_name: + model_info = registered_model + print_info(f"Detected model type: {registered_model.name}") + break + if model_info: + break + + # Also check HuggingFace repo format (org--model) + if not model_info: + for registered_model in registry.list_all(): + repo_slug = registered_model.hf_repo.replace("/", "--").lower() + if repo_slug in dir_name or dir_name in repo_slug: + model_info = registered_model + print_info(f"Detected model type: {registered_model.name}") + break + + if not model_info: + print_warning("Could not detect model type from path. 
Using default parameters.") + console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]") + else: + # Search in registry + matches = registry.search(model) + + if not matches: + print_error(t("run_model_not_found", name=model)) + console.print() + console.print("Available models:") + for m in registry.list_all()[:5]: + console.print(f" - {m.name} ({', '.join(m.aliases[:2])})") + raise typer.Exit(1) + + if len(matches) == 1: + model_info = matches[0] + else: + # Multiple matches - prompt user + console.print() + print_info(t("run_multiple_matches")) + choices = [f"{m.name} ({m.hf_repo})" for m in matches] + selected = prompt_choice(t("run_select_model"), choices) + idx = choices.index(selected) + model_info = matches[idx] + + # Find model path + if model_path is None: + resolved_model_path = _find_model_path(model_info, settings) + if resolved_model_path is None: + print_error(t("run_model_not_found", name=model_info.name)) + console.print() + console.print( + f" Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}" + ) + raise typer.Exit(1) + + print_info(t("run_model_path", path=str(resolved_model_path))) + + # Step 3: Check quantized weights (only if explicitly requested) + resolved_weights_path = None + + # Only use quantized weights if explicitly specified by user + if weights_path is not None: + # User explicitly specified weights path + resolved_weights_path = weights_path + if not resolved_weights_path.exists(): + print_error(t("run_weights_not_found")) + console.print(f" Path: {resolved_weights_path}") + raise typer.Exit(1) + print_info(f"Using quantized weights: {resolved_weights_path}") + elif quantize: + # User requested quantization + console.print() + print_step(t("run_quantizing")) + # TODO: Implement quantization + print_warning("Quantization not yet implemented. 
Please run 'kt quant' manually.") + raise typer.Exit(1) + else: + # Default: use original precision model without quantization + console.print() + print_info("Using original precision model (no quantization)") + + # Step 4: Build command + # Resolve all parameters (CLI > model defaults > config > auto-detect) + final_host = host or settings.get("server.host", "0.0.0.0") + final_port = port or settings.get("server.port", 30000) + + # Get defaults from model info if available + model_defaults = model_info.default_params if model_info else {} + + # Determine tensor parallel size first (needed for GPU expert calculation) + # Priority: CLI > model defaults > config > auto-detect (with model constraints) + + # Check if explicitly specified by user or configuration + explicitly_specified = ( + tensor_parallel_size # CLI argument (highest priority) + or model_defaults.get("tensor-parallel-size") # Model defaults + or settings.get("inference.tensor_parallel_size") # Config file + ) + + if explicitly_specified: + # Use explicitly specified value + requested_tensor_parallel_size = explicitly_specified + else: + # Auto-detect from GPUs, considering model's max constraint + detected_gpu_count = len(gpus) if gpus else 1 + if model_info and model_info.max_tensor_parallel_size is not None: + # Automatically limit to model's maximum to use as many GPUs as possible + requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size) + else: + requested_tensor_parallel_size = detected_gpu_count + + # Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it + final_tensor_parallel_size = requested_tensor_parallel_size + if model_info and model_info.max_tensor_parallel_size is not None: + if requested_tensor_parallel_size > model_info.max_tensor_parallel_size: + console.print() + print_warning( + f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way " + f"tensor parallelism, but 
{requested_tensor_parallel_size} was requested. " + f"Reducing to {model_info.max_tensor_parallel_size}." + ) + final_tensor_parallel_size = model_info.max_tensor_parallel_size + + # CPU/GPU configuration with smart defaults + # kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes) + total_threads = cpu.cores * cpu.numa_nodes + final_cpu_threads = ( + cpu_threads + or model_defaults.get("kt-cpuinfer") + or settings.get("inference.cpu_threads") + or int(total_threads * 0.8) + ) + + # kt-threadpool-count: default to NUMA node count + final_numa_nodes = ( + numa_nodes + or model_defaults.get("kt-threadpool-count") + or settings.get("inference.numa_nodes") + or cpu.numa_nodes + ) + + # kt-num-gpu-experts: use model-specific computation if available and not explicitly set + if gpu_experts is not None: + # User explicitly set it + final_gpu_experts = gpu_experts + elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus: + # Use model-specific computation function (only if GPUs detected) + vram_per_gpu = gpus[0].vram_gb + compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name] + final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu) + console.print() + print_info( + f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)" + ) + else: + # Fall back to defaults + final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1) + + # KT-kernel options + final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4") + final_kt_gpu_prefill_threshold = ( + kt_gpu_prefill_token_threshold + or model_defaults.get("kt-gpu-prefill-token-threshold") + or settings.get("inference.kt_gpu_prefill_token_threshold", 4096) + ) + + # SGLang options + final_attention_backend = ( + attention_backend + or model_defaults.get("attention-backend") + or 
settings.get("inference.attention_backend", "triton") + ) + final_max_total_tokens = ( + max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000) + ) + final_max_running_requests = ( + max_running_requests + or model_defaults.get("max-running-requests") + or settings.get("inference.max_running_requests", 32) + ) + final_chunked_prefill_size = ( + chunked_prefill_size + or model_defaults.get("chunked-prefill-size") + or settings.get("inference.chunked_prefill_size", 4096) + ) + final_mem_fraction_static = ( + mem_fraction_static + or model_defaults.get("mem-fraction-static") + or settings.get("inference.mem_fraction_static", 0.98) + ) + final_watchdog_timeout = ( + watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000) + ) + final_served_model_name = ( + served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "") + ) + + # Performance flags + if disable_shared_experts_fusion is not None: + final_disable_shared_experts_fusion = disable_shared_experts_fusion + elif "disable-shared-experts-fusion" in model_defaults: + final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"] + else: + final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False) + + # Pass all model default params to handle any extra parameters + extra_params = model_defaults if model_info else {} + + cmd = _build_sglang_command( + model_path=resolved_model_path, + weights_path=resolved_weights_path, + model_info=model_info, + host=final_host, + port=final_port, + gpu_experts=final_gpu_experts, + cpu_threads=final_cpu_threads, + numa_nodes=final_numa_nodes, + tensor_parallel_size=final_tensor_parallel_size, + kt_method=final_kt_method, + kt_gpu_prefill_threshold=final_kt_gpu_prefill_threshold, + attention_backend=final_attention_backend, + 
max_total_tokens=final_max_total_tokens, + max_running_requests=final_max_running_requests, + chunked_prefill_size=final_chunked_prefill_size, + mem_fraction_static=final_mem_fraction_static, + watchdog_timeout=final_watchdog_timeout, + served_model_name=final_served_model_name, + disable_shared_experts_fusion=final_disable_shared_experts_fusion, + settings=settings, + extra_model_params=extra_params, + ) + + # Prepare environment variables + env = os.environ.copy() + # Add environment variables from advanced.env + env.update(settings.get_env_vars()) + # Add environment variables from inference.env + inference_env = settings.get("inference.env", {}) + if isinstance(inference_env, dict): + env.update({k: str(v) for k, v in inference_env.items()}) + + # Step 5: Show configuration summary + console.print() + print_step("Configuration") + + # Model info + if model_info: + console.print(f" Model: [bold]{model_info.name}[/bold]") + else: + console.print(f" Model: [bold]{resolved_model_path.name}[/bold]") + + console.print(f" Path: [dim]{resolved_model_path}[/dim]") + + # Key parameters + console.print() + console.print(f" GPU Experts: [cyan]{final_gpu_experts}[/cyan] per layer") + console.print(f" CPU Threads (kt-cpuinfer): [cyan]{final_cpu_threads}[/cyan]") + console.print(f" NUMA Nodes (kt-threadpool-count): [cyan]{final_numa_nodes}[/cyan]") + console.print(f" Tensor Parallel: [cyan]{final_tensor_parallel_size}[/cyan]") + console.print(f" Method: [cyan]{final_kt_method}[/cyan]") + console.print(f" Attention: [cyan]{final_attention_backend}[/cyan]") + + # Weights info + if resolved_weights_path: + console.print() + console.print(f" Quantized weights: [yellow]{resolved_weights_path}[/yellow]") + + console.print() + console.print(f" Server: [green]http://{final_host}:{final_port}[/green]") + console.print() + + # Step 6: Show or execute + if dry_run: + console.print() + console.print("[bold]Command:[/bold]") + console.print() + console.print(f" [dim]{' 
'.join(cmd)}[/dim]") + console.print() + return + + # Execute with prepared environment variables + # Don't print "Server started" or API info here - let sglang's logs speak for themselves + # The actual startup takes time and these messages are misleading + + # Print the command being executed + console.print() + console.print("[bold]Launching server with command:[/bold]") + console.print() + console.print(f" [dim]{' '.join(cmd)}[/dim]") + console.print() + + try: + # Execute directly without intercepting output or signals + # This allows direct output to terminal and Ctrl+C to work naturally + process = subprocess.run(cmd, env=env) + sys.exit(process.returncode) + + except FileNotFoundError: + from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions + + print_error(t("sglang_not_found")) + console.print() + print_sglang_install_instructions() + raise typer.Exit(1) + except Exception as e: + print_error(f"Failed to start server: {e}") + raise typer.Exit(1) + + +def _find_model_path(model_info: ModelInfo, settings) -> Optional[Path]: + """Find the model path on disk by searching all configured model paths.""" + model_paths = settings.get_model_paths() + + # Search in all configured model directories + for models_dir in model_paths: + # Check common path patterns + possible_paths = [ + models_dir / model_info.name, + models_dir / model_info.name.lower(), + models_dir / model_info.name.replace(" ", "-"), + models_dir / model_info.hf_repo.split("/")[-1], + models_dir / model_info.hf_repo.replace("/", "--"), + ] + + # Add alias-based paths + for alias in model_info.aliases: + possible_paths.append(models_dir / alias) + possible_paths.append(models_dir / alias.lower()) + + for path in possible_paths: + if path.exists() and (path / "config.json").exists(): + return path + + return None + + +def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]: + """Find the quantized weights path on disk by searching all configured paths.""" + 
model_paths = settings.get_model_paths() + weights_dir = settings.weights_dir + + # Check common patterns + base_names = [ + model_info.name, + model_info.name.lower(), + model_info.hf_repo.split("/")[-1], + ] + + suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"] + + # Prepare search directories + search_dirs = [weights_dir] if weights_dir else [] + search_dirs.extend(model_paths) + + for base in base_names: + for suffix in suffixes: + for dir_path in search_dirs: + if dir_path: + path = dir_path / f"{base}{suffix}" + if path.exists(): + return path + + return None + + +def _build_sglang_command( + model_path: Path, + weights_path: Optional[Path], + model_info: Optional[ModelInfo], + host: str, + port: int, + gpu_experts: int, + cpu_threads: int, + numa_nodes: int, + tensor_parallel_size: int, + kt_method: str, + kt_gpu_prefill_threshold: int, + attention_backend: str, + max_total_tokens: int, + max_running_requests: int, + chunked_prefill_size: int, + mem_fraction_static: float, + watchdog_timeout: int, + served_model_name: str, + disable_shared_experts_fusion: bool, + settings, + extra_model_params: Optional[dict] = None, # New parameter for additional params +) -> list[str]: + """Build the SGLang launch command.""" + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--host", + host, + "--port", + str(port), + "--model", + str(model_path), + ] + + # Add kt-kernel options + # kt-kernel is needed for: + # 1. Quantized models (when weights_path is provided) + # 2. 
MoE models with CPU offloading (when kt-cpuinfer > 0 or kt-num-gpu-experts is configured) + use_kt_kernel = False + + # Check if we should use kt-kernel + if weights_path: + # Quantized model - always use kt-kernel + use_kt_kernel = True + elif cpu_threads > 0 or gpu_experts > 1: + # CPU offloading configured - use kt-kernel + use_kt_kernel = True + elif model_info and model_info.type == "moe": + # MoE model - likely needs kt-kernel for expert offloading + use_kt_kernel = True + + if use_kt_kernel: + # Add kt-weight-path: use quantized weights if available, otherwise use model path + weight_path_to_use = weights_path if weights_path else model_path + + # Add kt-kernel configuration + cmd.extend( + [ + "--kt-weight-path", + str(weight_path_to_use), + "--kt-cpuinfer", + str(cpu_threads), + "--kt-threadpool-count", + str(numa_nodes), + "--kt-num-gpu-experts", + str(gpu_experts), + "--kt-method", + kt_method, + "--kt-gpu-prefill-token-threshold", + str(kt_gpu_prefill_threshold), + ] + ) + + # Add SGLang options + cmd.extend( + [ + "--attention-backend", + attention_backend, + "--trust-remote-code", + "--mem-fraction-static", + str(mem_fraction_static), + "--chunked-prefill-size", + str(chunked_prefill_size), + "--max-running-requests", + str(max_running_requests), + "--max-total-tokens", + str(max_total_tokens), + "--watchdog-timeout", + str(watchdog_timeout), + "--enable-mixed-chunk", + "--tensor-parallel-size", + str(tensor_parallel_size), + "--enable-p2p-check", + ] + ) + + # Add served model name if specified + if served_model_name: + cmd.extend(["--served-model-name", served_model_name]) + + # Add performance flags + if disable_shared_experts_fusion: + cmd.append("--disable-shared-experts-fusion") + + # Add any extra parameters from model defaults that weren't explicitly handled + if extra_model_params: + # List of parameters already handled above + handled_params = { + "kt-num-gpu-experts", + "kt-cpuinfer", + "kt-threadpool-count", + "kt-method", + 
"kt-gpu-prefill-token-threshold", + "attention-backend", + "tensor-parallel-size", + "max-total-tokens", + "max-running-requests", + "chunked-prefill-size", + "mem-fraction-static", + "watchdog-timeout", + "served-model-name", + "disable-shared-experts-fusion", + } + + for key, value in extra_model_params.items(): + if key not in handled_params: + # Add unhandled parameters dynamically + cmd.append(f"--{key}") + if isinstance(value, bool): + # Boolean flags don't need a value + if not value: + # For False boolean, skip the flag entirely + cmd.pop() # Remove the flag we just added + else: + cmd.append(str(value)) + + # Add extra args from settings + extra_args = settings.get("advanced.sglang_args", []) + if extra_args: + cmd.extend(extra_args) + + return cmd + + +def _interactive_model_selection(registry, settings) -> Optional[str]: + """Show interactive model selection interface. + + Returns: + Selected model name or None if cancelled. + """ + from rich.panel import Panel + from rich.table import Table + from rich.prompt import Prompt + + from kt_kernel.cli.i18n import get_lang + + lang = get_lang() + + # Find local models first + local_models = registry.find_local_models() + + # Get all registered models + all_models = registry.list_all() + + console.print() + console.print( + Panel.fit( + t("run_select_model_title"), + border_style="cyan", + ) + ) + console.print() + + # Build choices list + choices = [] + choice_map = {} # index -> model name + + # Section 1: Local models (downloaded) + if local_models: + console.print(f"[bold green]{t('run_local_models')}[/bold green]") + console.print() + + for i, (model_info, path) in enumerate(local_models, 1): + desc = model_info.description_zh if lang == "zh" else model_info.description + short_desc = desc[:50] + "..." 
if len(desc) > 50 else desc + console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]") + console.print(f" [dim]{short_desc}[/dim]") + console.print(f" [dim]{path}[/dim]") + choices.append(str(i)) + choice_map[str(i)] = model_info.name + + console.print() + + # Section 2: All registered models (for reference) + start_idx = len(local_models) + 1 + console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]") + console.print() + + # Filter out already shown local models + local_model_names = {m.name for m, _ in local_models} + + for i, model_info in enumerate(all_models, start_idx): + if model_info.name in local_model_names: + continue + + desc = model_info.description_zh if lang == "zh" else model_info.description + short_desc = desc[:50] + "..." if len(desc) > 50 else desc + console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]") + console.print(f" [dim]{short_desc}[/dim]") + console.print(f" [dim]{model_info.hf_repo}[/dim]") + choices.append(str(i)) + choice_map[str(i)] = model_info.name + + console.print() + + # Add cancel option + cancel_idx = str(len(choices) + 1) + console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]") + choices.append(cancel_idx) + console.print() + + # Prompt for selection + try: + selection = Prompt.ask( + t("run_select_model_prompt"), + choices=choices, + default="1" if choices else cancel_idx, + ) + except KeyboardInterrupt: + console.print() + return None + + if selection == cancel_idx: + return None + + return choice_map.get(selection) diff --git a/kt-kernel/python/cli/commands/sft.py b/kt-kernel/python/cli/commands/sft.py new file mode 100644 index 0000000..9a665f2 --- /dev/null +++ b/kt-kernel/python/cli/commands/sft.py @@ -0,0 +1,52 @@ +""" +SFT command for kt-cli. + +Fine-tuning with LlamaFactory integration. 
+""" + +import typer + +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import console + +app = typer.Typer(help="Fine-tuning with LlamaFactory (coming soon)") + + +@app.callback(invoke_without_command=True) +def callback(ctx: typer.Context) -> None: + """Fine-tuning commands (coming soon).""" + if ctx.invoked_subcommand is None: + console.print() + console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]") + console.print() + console.print("[dim]kt sft train - Train a model[/dim]") + console.print("[dim]kt sft chat - Chat with a trained model[/dim]") + console.print("[dim]kt sft export - Export a trained model[/dim]") + console.print() + + +@app.command(name="train") +def train() -> None: + """Train a model using LlamaFactory (coming soon).""" + console.print() + console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]") + console.print() + raise typer.Exit(0) + + +@app.command(name="chat") +def chat() -> None: + """Chat with a trained model using LlamaFactory (coming soon).""" + console.print() + console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]") + console.print() + raise typer.Exit(0) + + +@app.command(name="export") +def export() -> None: + """Export a trained model using LlamaFactory (coming soon).""" + console.print() + console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]") + console.print() + raise typer.Exit(0) diff --git a/kt-kernel/python/cli/commands/version.py b/kt-kernel/python/cli/commands/version.py new file mode 100644 index 0000000..3d4adf2 --- /dev/null +++ b/kt-kernel/python/cli/commands/version.py @@ -0,0 +1,118 @@ +""" +Version command for kt-cli. + +Displays version information for kt-cli and related packages. 
+""" + +import platform +from typing import Optional + +import typer + +from kt_kernel.cli import __version__ +from kt_kernel.cli.i18n import t +from kt_kernel.cli.utils.console import console, print_version_table +from kt_kernel.cli.utils.environment import detect_cuda_version, get_installed_package_version + + +def _get_sglang_info() -> str: + """Get sglang version and installation source information.""" + try: + import sglang + + version = getattr(sglang, "__version__", None) + + if not version: + version = get_installed_package_version("sglang") + + if not version: + return t("version_not_installed") + + # Try to detect installation source + from pathlib import Path + import subprocess + + if hasattr(sglang, "__file__") and sglang.__file__: + location = Path(sglang.__file__).parent.parent + git_dir = location / ".git" + + if git_dir.exists(): + # Installed from git (editable install) + try: + # Get remote URL + result = subprocess.run( + ["git", "remote", "get-url", "origin"], + cwd=location, + capture_output=True, + text=True, + timeout=2, + ) + if result.returncode == 0: + remote_url = result.stdout.strip() + # Simplify GitHub URLs + if "github.com" in remote_url: + repo_name = remote_url.split("/")[-1].replace(".git", "") + owner = remote_url.split("/")[-2] + return f"{version} [dim](GitHub: {owner}/{repo_name})[/dim]" + return f"{version} [dim](Git: {remote_url})[/dim]" + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + pass + + # Default: installed from PyPI + return f"{version} [dim](PyPI)[/dim]" + + except ImportError: + return t("version_not_installed") + + +def version( + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed version info"), +) -> None: + """Show version information.""" + console.print(f"\n[bold]{t('version_info')}[/bold] v{__version__}\n") + + # Basic info + versions = { + t("version_python"): platform.python_version(), + t("version_platform"): f"{platform.system()} {platform.release()}", + } + + # 
CUDA version + cuda_version = detect_cuda_version() + versions[t("version_cuda")] = cuda_version or t("version_cuda_not_found") + + print_version_table(versions) + + # Always show key packages with installation source + console.print("\n[bold]Packages:[/bold]\n") + + sglang_info = _get_sglang_info() + key_packages = { + t("version_kt_kernel"): get_installed_package_version("kt-kernel") or t("version_not_installed"), + t("version_sglang"): sglang_info, + } + + print_version_table(key_packages) + + # Show SGLang installation hint if not installed + if sglang_info == t("version_not_installed"): + from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions + + console.print() + print_sglang_install_instructions() + + if verbose: + console.print("\n[bold]Additional Packages:[/bold]\n") + + package_versions = { + t("version_ktransformers"): get_installed_package_version("ktransformers") or t("version_not_installed"), + t("version_llamafactory"): get_installed_package_version("llamafactory") or t("version_not_installed"), + "typer": get_installed_package_version("typer") or t("version_not_installed"), + "rich": get_installed_package_version("rich") or t("version_not_installed"), + "torch": get_installed_package_version("torch") or t("version_not_installed"), + "transformers": get_installed_package_version("transformers") or t("version_not_installed"), + } + + print_version_table(package_versions) + + console.print() diff --git a/kt-kernel/python/cli/completions/__init__.py b/kt-kernel/python/cli/completions/__init__.py new file mode 100644 index 0000000..d32c4eb --- /dev/null +++ b/kt-kernel/python/cli/completions/__init__.py @@ -0,0 +1 @@ +"""Shell completion scripts for kt-cli.""" diff --git a/kt-kernel/python/cli/completions/_kt b/kt-kernel/python/cli/completions/_kt new file mode 100644 index 0000000..e3850dd --- /dev/null +++ b/kt-kernel/python/cli/completions/_kt @@ -0,0 +1,153 @@ +#compdef kt +# Zsh completion for kt command +# This is a static 
completion script that doesn't require Python startup + +_kt() { + local -a commands + commands=( + 'version:Show version information' + 'run:Start model inference server' + 'chat:Interactive chat with running model' + 'quant:Quantize model weights' + 'bench:Run full benchmark' + 'microbench:Run micro-benchmark' + 'doctor:Diagnose environment issues' + 'model:Manage models and storage paths' + 'config:Manage configuration' + 'sft:Fine-tuning with LlamaFactory' + ) + + local -a run_opts + run_opts=( + '--host[Server host]:host:' + '--port[Server port]:port:' + '--gpu-experts[Number of GPU experts]:count:' + '--cpu-threads[Number of CPU threads]:count:' + '--tensor-parallel-size[Tensor parallel size]:size:' + '--kt-method[KT method]:method:(AMXINT4 FP8 RAWINT4)' + '--attention-backend[Attention backend]:backend:(triton flashinfer)' + '--max-total-tokens[Maximum total tokens]:tokens:' + '--dry-run[Show command without executing]' + '--help[Show help message]' + ) + + local -a chat_opts + chat_opts=( + '--host[Server host]:host:' + '--port[Server port]:port:' + '--model[Model name]:model:' + '--temperature[Sampling temperature]:temp:' + '--max-tokens[Maximum tokens]:tokens:' + '--system[System prompt]:prompt:' + '--save-history[Save conversation history]' + '--no-save-history[Do not save history]' + '--history-file[History file path]:path:_files' + '--stream[Enable streaming output]' + '--no-stream[Disable streaming output]' + '--help[Show help message]' + ) + + local -a model_cmds + model_cmds=( + 'download:Download a model from HuggingFace' + 'list:List available models' + 'path-list:List all model storage paths' + 'path-add:Add a new model storage path' + 'path-remove:Remove a model storage path' + 'search:Search for models in the registry' + ) + + local -a config_cmds + config_cmds=( + 'show:Show all configuration' + 'get:Get configuration value' + 'set:Set configuration value' + 'reset:Reset to defaults' + 'path:Show configuration file path' + 'init:Re-run 
first-time setup wizard' + ) + + local -a sft_cmds + sft_cmds=( + 'train:Train model' + 'chat:Chat with model' + 'export:Export model' + ) + + _arguments -C \ + '1: :->command' \ + '*::arg:->args' + + case $state in + command) + _describe 'kt commands' commands + _arguments \ + '--help[Show help message]' \ + '--version[Show version]' + ;; + args) + case $words[1] in + run) + _arguments $run_opts \ + '1:model:' + ;; + chat) + _arguments $chat_opts + ;; + quant) + _arguments \ + '--method[Quantization method]:method:' \ + '--output[Output directory]:path:_files -/' \ + '--help[Show help message]' \ + '1:model:_files -/' + ;; + bench|microbench) + _arguments \ + '--model[Model name or path]:model:' \ + '--config[Config file path]:path:_files' \ + '--help[Show help message]' + ;; + doctor) + _arguments \ + '--verbose[Verbose output]' \ + '--help[Show help message]' + ;; + model) + _arguments \ + '1: :->model_cmd' \ + '*::arg:->model_args' + + case $state in + model_cmd) + _describe 'model commands' model_cmds + ;; + esac + ;; + config) + _arguments \ + '1: :->config_cmd' \ + '*::arg:->config_args' + + case $state in + config_cmd) + _describe 'config commands' config_cmds + ;; + esac + ;; + sft) + _arguments \ + '1: :->sft_cmd' \ + '*::arg:->sft_args' + + case $state in + sft_cmd) + _describe 'sft commands' sft_cmds + ;; + esac + ;; + esac + ;; + esac +} + +_kt "$@" diff --git a/kt-kernel/python/cli/completions/kt-completion.bash b/kt-kernel/python/cli/completions/kt-completion.bash new file mode 100644 index 0000000..8f1d3be --- /dev/null +++ b/kt-kernel/python/cli/completions/kt-completion.bash @@ -0,0 +1,73 @@ +#!/bin/bash +# Bash completion for kt command +# This is a static completion script that doesn't require Python startup + +_kt_completion() { + local cur prev opts + COMPREPLY=() + cur="${COMP_WORDS[COMP_CWORD]}" + prev="${COMP_WORDS[COMP_CWORD-1]}" + + # Main commands + local commands="version run chat quant bench microbench doctor model config sft" + + # 
Global options + local global_opts="--help --version" + + # Handle subcommands + case "${COMP_CWORD}" in + 1) + # First argument: suggest commands and global options + COMPREPLY=( $(compgen -W "${commands} ${global_opts}" -- ${cur}) ) + return 0 + ;; + *) + # Handle specific command options + case "${COMP_WORDS[1]}" in + run) + local run_opts="--host --port --gpu-experts --cpu-threads --tensor-parallel-size --kt-method --attention-backend --max-total-tokens --dry-run --help" + COMPREPLY=( $(compgen -W "${run_opts}" -- ${cur}) ) + ;; + chat) + local chat_opts="--host --port --model --temperature --max-tokens --system --save-history --no-save-history --history-file --stream --no-stream --help" + COMPREPLY=( $(compgen -W "${chat_opts}" -- ${cur}) ) + ;; + quant) + local quant_opts="--method --output --help" + COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) ) + ;; + bench|microbench) + local bench_opts="--model --config --help" + COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) ) + ;; + doctor) + local doctor_opts="--verbose --help" + COMPREPLY=( $(compgen -W "${doctor_opts}" -- ${cur}) ) + ;; + model) + local model_cmds="download list path-list path-add path-remove search" + local model_opts="--help" + COMPREPLY=( $(compgen -W "${model_cmds} ${model_opts}" -- ${cur}) ) + ;; + config) + local config_cmds="show get set reset path init model-path-list model-path-add model-path-remove" + local config_opts="--help" + COMPREPLY=( $(compgen -W "${config_cmds} ${config_opts}" -- ${cur}) ) + ;; + sft) + local sft_cmds="train chat export" + local sft_opts="--help" + COMPREPLY=( $(compgen -W "${sft_cmds} ${sft_opts}" -- ${cur}) ) + ;; + version) + COMPREPLY=( $(compgen -W "--help" -- ${cur}) ) + ;; + *) + COMPREPLY=() + ;; + esac + ;; + esac +} + +complete -F _kt_completion kt diff --git a/kt-kernel/python/cli/completions/kt.fish b/kt-kernel/python/cli/completions/kt.fish new file mode 100644 index 0000000..7b85504 --- /dev/null +++ 
b/kt-kernel/python/cli/completions/kt.fish @@ -0,0 +1,74 @@ +# Fish completion for kt command +# This is a static completion script that doesn't require Python startup + +# Main commands +complete -c kt -f -n "__fish_use_subcommand" -a "version" -d "Show version information" +complete -c kt -f -n "__fish_use_subcommand" -a "run" -d "Start model inference server" +complete -c kt -f -n "__fish_use_subcommand" -a "chat" -d "Interactive chat with running model" +complete -c kt -f -n "__fish_use_subcommand" -a "quant" -d "Quantize model weights" +complete -c kt -f -n "__fish_use_subcommand" -a "bench" -d "Run full benchmark" +complete -c kt -f -n "__fish_use_subcommand" -a "microbench" -d "Run micro-benchmark" +complete -c kt -f -n "__fish_use_subcommand" -a "doctor" -d "Diagnose environment issues" +complete -c kt -f -n "__fish_use_subcommand" -a "model" -d "Manage models and storage paths" +complete -c kt -f -n "__fish_use_subcommand" -a "config" -d "Manage configuration" +complete -c kt -f -n "__fish_use_subcommand" -a "sft" -d "Fine-tuning with LlamaFactory" + +# Global options +complete -c kt -l help -d "Show help message" +complete -c kt -l version -d "Show version" + +# Run command options +complete -c kt -f -n "__fish_seen_subcommand_from run" -l host -d "Server host" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l port -d "Server port" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l gpu-experts -d "Number of GPU experts" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l cpu-threads -d "Number of CPU threads" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l tensor-parallel-size -d "Tensor parallel size" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l kt-method -d "KT method" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l attention-backend -d "Attention backend" +complete -c kt -f -n "__fish_seen_subcommand_from run" -l max-total-tokens -d "Maximum total tokens" +complete -c kt -f -n 
"__fish_seen_subcommand_from run" -l dry-run -d "Show command without executing" + +# Chat command options +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l host -d "Server host" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l port -d "Server port" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l model -d "Model name" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l temperature -d "Sampling temperature" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l max-tokens -d "Maximum tokens" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l system -d "System prompt" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l save-history -d "Save conversation history" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-save-history -d "Do not save history" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l history-file -d "History file path" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l stream -d "Enable streaming output" +complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-stream -d "Disable streaming output" + +# Quant command options +complete -c kt -f -n "__fish_seen_subcommand_from quant" -l method -d "Quantization method" +complete -c kt -f -n "__fish_seen_subcommand_from quant" -l output -d "Output directory" + +# Bench command options +complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l model -d "Model name or path" +complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l config -d "Config file path" + +# Doctor command options +complete -c kt -f -n "__fish_seen_subcommand_from doctor" -l verbose -d "Verbose output" + +# Model subcommands +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "download" -d "Download a model from HuggingFace" +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from 
download list path-list path-add path-remove search" -a "list" -d "List available models" +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-list" -d "List all model storage paths" +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-add" -d "Add a new model storage path" +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-remove" -d "Remove a model storage path" +complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "search" -d "Search for models in the registry" + +# Config subcommands +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "show" -d "Show all configuration" +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "get" -d "Get configuration value" +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "set" -d "Set configuration value" +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "reset" -d "Reset to defaults" +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "path" -d "Show configuration file path" +complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "init" -d "Re-run first-time setup wizard" + +# SFT subcommands +complete -c kt -f -n "__fish_seen_subcommand_from sft; and not 
__fish_seen_subcommand_from train chat export" -a "train" -d "Train model" +complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "chat" -d "Chat with model" +complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "export" -d "Export model" diff --git a/kt-kernel/python/cli/config/__init__.py b/kt-kernel/python/cli/config/__init__.py new file mode 100644 index 0000000..2d6a11d --- /dev/null +++ b/kt-kernel/python/cli/config/__init__.py @@ -0,0 +1,7 @@ +""" +Configuration management for kt-cli. +""" + +from kt_kernel.cli.config.settings import Settings, get_settings + +__all__ = ["Settings", "get_settings"] diff --git a/kt-kernel/python/cli/config/settings.py b/kt-kernel/python/cli/config/settings.py new file mode 100644 index 0000000..2cc5f05 --- /dev/null +++ b/kt-kernel/python/cli/config/settings.py @@ -0,0 +1,311 @@ +""" +Configuration management for kt-cli. + +Handles reading and writing configuration from ~/.ktransformers/config.yaml +""" + +import os +from pathlib import Path +from typing import Any, Optional + +import yaml + +# Default configuration directory +DEFAULT_CONFIG_DIR = Path.home() / ".ktransformers" +DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml" +DEFAULT_MODELS_DIR = DEFAULT_CONFIG_DIR / "models" +DEFAULT_CACHE_DIR = DEFAULT_CONFIG_DIR / "cache" + +# Default configuration values +DEFAULT_CONFIG = { + "general": { + "language": "auto", # auto, en, zh + "color": True, + "verbose": False, + }, + "paths": { + "models": str(DEFAULT_MODELS_DIR), + "cache": str(DEFAULT_CACHE_DIR), + "weights": "", # Custom quantized weights path + }, + "server": { + "host": "0.0.0.0", + "port": 30000, + }, + "inference": { + # Inference parameters are model-specific and should not have defaults + # They will be auto-detected or use model-specific optimizations + # Environment variables (general optimizations) + "env": { + "PYTORCH_ALLOC_CONF": 
"expandable_segments:True", + "SGLANG_ENABLE_JIT_DEEPGEMM": "0", + }, + }, + "download": { + "mirror": "", # HuggingFace mirror URL + "resume": True, + "verify": True, + }, + "advanced": { + # Environment variables to set when running + "env": {}, + # Extra arguments to pass to sglang + "sglang_args": [], + # Extra arguments to pass to llamafactory + "llamafactory_args": [], + }, + "dependencies": { + # SGLang installation source configuration + "sglang": { + "source": "github", # "pypi" or "github" + "repo": "https://github.com/kvcache-ai/sglang", + "branch": "main", + }, + }, +} + + +class Settings: + """Configuration manager for kt-cli.""" + + def __init__(self, config_path: Optional[Path] = None): + """Initialize settings manager. + + Args: + config_path: Path to config file. Defaults to ~/.ktransformers/config.yaml + """ + self.config_path = config_path or DEFAULT_CONFIG_FILE + self.config_dir = self.config_path.parent + self._config: dict[str, Any] = {} + self._load() + + def _ensure_dirs(self) -> None: + """Ensure configuration directories exist.""" + self.config_dir.mkdir(parents=True, exist_ok=True) + + # Ensure all model paths exist + model_paths = self.get_model_paths() + for path in model_paths: + path.mkdir(parents=True, exist_ok=True) + + Path(self.get("paths.cache", DEFAULT_CACHE_DIR)).mkdir(parents=True, exist_ok=True) + + def _load(self) -> None: + """Load configuration from file.""" + self._config = self._deep_copy(DEFAULT_CONFIG) + + if self.config_path.exists(): + try: + with open(self.config_path, "r", encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + self._deep_merge(self._config, user_config) + except (yaml.YAMLError, OSError) as e: + # Log warning but continue with defaults + print(f"Warning: Failed to load config: {e}") + + self._ensure_dirs() + + def _save(self) -> None: + """Save configuration to file.""" + self._ensure_dirs() + try: + with open(self.config_path, "w", encoding="utf-8") as f: + yaml.dump(self._config, f, 
default_flow_style=False, allow_unicode=True) + except OSError as e: + raise RuntimeError(f"Failed to save config: {e}") + + def _deep_copy(self, obj: Any) -> Any: + """Create a deep copy of a nested dict.""" + if isinstance(obj, dict): + return {k: self._deep_copy(v) for k, v in obj.items()} + if isinstance(obj, list): + return [self._deep_copy(item) for item in obj] + return obj + + def _deep_merge(self, base: dict, override: dict) -> None: + """Deep merge override into base.""" + for key, value in override.items(): + if key in base and isinstance(base[key], dict) and isinstance(value, dict): + self._deep_merge(base[key], value) + else: + base[key] = value + + def get(self, key: str, default: Any = None) -> Any: + """Get a configuration value by dot-separated key. + + Args: + key: Dot-separated key path (e.g., "server.port") + default: Default value if key not found + + Returns: + Configuration value or default + """ + parts = key.split(".") + value = self._config + + for part in parts: + if isinstance(value, dict) and part in value: + value = value[part] + else: + return default + + return value + + def set(self, key: str, value: Any) -> None: + """Set a configuration value by dot-separated key. + + Args: + key: Dot-separated key path (e.g., "server.port") + value: Value to set + """ + parts = key.split(".") + config = self._config + + # Navigate to parent + for part in parts[:-1]: + if part not in config: + config[part] = {} + config = config[part] + + # Set value + config[parts[-1]] = value + self._save() + + def delete(self, key: str) -> bool: + """Delete a configuration value. 
+ + Args: + key: Dot-separated key path + + Returns: + True if key was deleted, False if not found + """ + parts = key.split(".") + config = self._config + + # Navigate to parent + for part in parts[:-1]: + if part not in config: + return False + config = config[part] + + # Delete key + if parts[-1] in config: + del config[parts[-1]] + self._save() + return True + return False + + def reset(self) -> None: + """Reset configuration to defaults.""" + self._config = self._deep_copy(DEFAULT_CONFIG) + self._save() + + def get_all(self) -> dict[str, Any]: + """Get all configuration values.""" + return self._deep_copy(self._config) + + def get_env_vars(self) -> dict[str, str]: + """Get environment variables to set.""" + env_vars = {} + + # Get from advanced.env + advanced_env = self.get("advanced.env", {}) + if isinstance(advanced_env, dict): + env_vars.update({k: str(v) for k, v in advanced_env.items()}) + + return env_vars + + @property + def models_dir(self) -> Path: + """Get the primary models directory path (for backward compatibility).""" + paths = self.get_model_paths() + return paths[0] if paths else Path(DEFAULT_MODELS_DIR) + + def get_model_paths(self) -> list[Path]: + """Get all model directory paths. + + Returns a list of Path objects. 
Supports both: + - Single path: paths.models = "/path/to/models" + - Multiple paths: paths.models = ["/path/1", "/path/2"] + """ + models_config = self.get("paths.models", DEFAULT_MODELS_DIR) + + # Handle both string and list + if isinstance(models_config, str): + return [Path(models_config)] + elif isinstance(models_config, list): + return [Path(p) for p in models_config] + else: + return [Path(DEFAULT_MODELS_DIR)] + + def add_model_path(self, path: str) -> None: + """Add a new model path to the configuration.""" + models_config = self.get("paths.models", DEFAULT_MODELS_DIR) + + # Convert to list if it's a string + if isinstance(models_config, str): + paths = [models_config] + elif isinstance(models_config, list): + paths = list(models_config) + else: + paths = [] + + # Add new path if not already present + if path not in paths: + paths.append(path) + self.set("paths.models", paths) + + def remove_model_path(self, path: str) -> bool: + """Remove a model path from the configuration. + + Returns True if path was removed, False if not found. 
+ """ + models_config = self.get("paths.models", DEFAULT_MODELS_DIR) + + if isinstance(models_config, str): + # Can't remove if it's a single string + if models_config == path: + # Don't remove the last path + return False + return False + elif isinstance(models_config, list): + if path in models_config: + paths = list(models_config) + paths.remove(path) + # Don't allow removing all paths + if not paths: + return False + self.set("paths.models", paths if len(paths) > 1 else paths[0]) + return True + + return False + + @property + def cache_dir(self) -> Path: + """Get the cache directory path.""" + return Path(self.get("paths.cache", DEFAULT_CACHE_DIR)) + + @property + def weights_dir(self) -> Optional[Path]: + """Get the custom weights directory path.""" + weights = self.get("paths.weights", "") + return Path(weights) if weights else None + + +# Global settings instance +_settings: Optional[Settings] = None + + +def get_settings() -> Settings: + """Get the global settings instance.""" + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def reset_settings() -> None: + """Reset the global settings instance.""" + global _settings + _settings = None diff --git a/kt-kernel/python/cli/i18n.py b/kt-kernel/python/cli/i18n.py new file mode 100644 index 0000000..af90cba --- /dev/null +++ b/kt-kernel/python/cli/i18n.py @@ -0,0 +1,655 @@ +""" +Internationalization (i18n) module for kt-cli. + +Supports English and Chinese languages, with automatic detection based on +system locale or KT_LANG environment variable. 
+""" + +import os +from typing import Any + +# Message definitions for all supported languages +MESSAGES: dict[str, dict[str, str]] = { + "en": { + # General + "welcome": "Welcome to KTransformers!", + "goodbye": "Goodbye!", + "error": "Error", + "warning": "Warning", + "success": "Success", + "info": "Info", + "yes": "Yes", + "no": "No", + "cancel": "Cancel", + "confirm": "Confirm", + "done": "Done", + "failed": "Failed", + "skip": "Skip", + "back": "Back", + "next": "Next", + "retry": "Retry", + "abort": "Abort", + # Version command + "version_info": "KTransformers CLI", + "version_python": "Python", + "version_platform": "Platform", + "version_cuda": "CUDA", + "version_cuda_not_found": "Not found", + "version_kt_kernel": "kt-kernel", + "version_ktransformers": "ktransformers", + "version_sglang": "sglang", + "version_llamafactory": "llamafactory", + "version_not_installed": "Not installed", + # Install command + "install_detecting_env": "Detecting environment managers...", + "install_found": "Found {name} (version {version})", + "install_not_found": "Not found: {name}", + "install_checking_env": "Checking existing environments...", + "install_env_exists": "Found existing 'kt' environment", + "install_env_not_exists": "No 'kt' environment found", + "install_no_env_manager": "No virtual environment manager detected", + "install_select_method": "Please select installation method:", + "install_method_conda": "Create new conda environment 'kt' (Recommended)", + "install_method_venv": "Create new venv environment", + "install_method_uv": "Create new uv environment (Fast)", + "install_method_docker": "Use Docker container", + "install_method_system": "Install to system Python (Not recommended)", + "install_select_mode": "Please select installation mode:", + "install_mode_inference": "Inference - Install kt-kernel + SGLang", + "install_mode_sft": "Training - Install kt-sft + LlamaFactory", + "install_mode_full": "Full - Install all components", + "install_creating_env": 
"Creating {type} environment '{name}'...", + "install_env_created": "Environment created successfully", + "install_installing_deps": "Installing dependencies...", + "install_checking_deps": "Checking dependency versions...", + "install_dep_ok": "OK", + "install_dep_outdated": "Needs update", + "install_dep_missing": "Missing", + "install_installing_pytorch": "Installing PyTorch...", + "install_installing_from_requirements": "Installing from requirements file...", + "install_deps_outdated": "Found {count} package(s) that need updating. Continue?", + "install_updating": "Updating packages...", + "install_complete": "Installation complete!", + "install_activate_hint": "Activate environment: {command}", + "install_start_hint": "Get started: kt run --help", + "install_docker_pulling": "Pulling Docker image...", + "install_docker_complete": "Docker image ready!", + "install_docker_run_hint": "Run with: docker run --gpus all -p 30000:30000 {image} kt run {model}", + "install_in_venv": "Running in virtual environment: {name}", + "install_continue_without_venv": "Continue installing to system Python?", + "install_already_installed": "All dependencies are already installed!", + "install_confirm": "Install {count} package(s)?", + # Install - System dependencies + "install_checking_system_deps": "Checking system dependencies...", + "install_dep_name": "Dependency", + "install_dep_status": "Status", + "install_deps_all_installed": "All system dependencies are installed", + "install_deps_install_prompt": "Install missing dependencies?", + "install_installing_system_deps": "Installing system dependencies...", + "install_installing_dep": "Installing {name}", + "install_dep_no_install_cmd": "No install command available for {name} on {os}", + "install_dep_install_failed": "Failed to install {name}", + "install_deps_skipped": "Skipping dependency installation", + "install_deps_failed": "Failed to install system dependencies", + # Install - CPU detection + "install_auto_detect_cpu": 
"Auto-detecting CPU capabilities...", + "install_cpu_features": "Detected CPU features: {features}", + "install_cpu_no_features": "No advanced CPU features detected", + # Install - Build configuration + "install_build_config": "Build Configuration:", + "install_native_warning": "Note: Binary optimized for THIS CPU only (not portable)", + "install_building_from_source": "Building kt-kernel from source...", + "install_build_failed": "Build failed", + "install_build_success": "Build completed successfully", + # Install - Verification + "install_verifying": "Verifying installation...", + "install_verify_success": "kt-kernel {version} ({variant} variant) installed successfully", + "install_verify_failed": "Verification failed: {error}", + # Install - Docker + "install_docker_guide_title": "Docker Installation", + "install_docker_guide_desc": "For Docker installation, please refer to the official guide:", + # Config command + "config_show_title": "Current Configuration", + "config_set_success": "Configuration updated: {key} = {value}", + "config_get_value": "{key} = {value}", + "config_get_not_found": "Configuration key '{key}' not found", + "config_reset_confirm": "This will reset all configurations to default. 
Continue?", + "config_reset_success": "Configuration reset to default", + "config_file_location": "Configuration file: {path}", + # Doctor command + "doctor_title": "KTransformers Environment Diagnostics", + "doctor_checking": "Running diagnostics...", + "doctor_check_python": "Python version", + "doctor_check_cuda": "CUDA availability", + "doctor_check_gpu": "GPU detection", + "doctor_check_cpu": "CPU", + "doctor_check_cpu_isa": "CPU Instructions", + "doctor_check_numa": "NUMA Topology", + "doctor_check_memory": "System memory", + "doctor_check_disk": "Disk space", + "doctor_check_packages": "Required packages", + "doctor_check_env": "Environment variables", + "doctor_status_ok": "OK", + "doctor_status_warning": "Warning", + "doctor_status_error": "Error", + "doctor_gpu_found": "Found {count} GPU(s): {names}", + "doctor_gpu_not_found": "No GPU detected", + "doctor_cpu_info": "{name} ({cores} cores / {threads} threads)", + "doctor_cpu_isa_info": "{isa_list}", + "doctor_cpu_isa_missing": "Missing recommended: {missing}", + "doctor_numa_info": "{nodes} node(s)", + "doctor_numa_detail": "{node}: CPUs {cpus}", + "doctor_memory_info": "{available} available / {total} total", + "doctor_memory_freq": "{available} available / {total} total ({freq}MHz {type})", + "doctor_disk_info": "{available} available at {path}", + "doctor_all_ok": "All checks passed! Your environment is ready.", + "doctor_has_issues": "Some issues were found. Please review the warnings/errors above.", + # Run command + "run_detecting_hardware": "Detecting hardware configuration...", + "run_gpu_info": "GPU: {name} ({vram}GB VRAM)", + "run_cpu_info": "CPU: {name} ({cores} cores, {numa} NUMA nodes)", + "run_ram_info": "RAM: {total}GB", + "run_checking_model": "Checking model status...", + "run_model_path": "Model path: {path}", + "run_weights_not_found": "Quantized weights not found", + "run_quant_prompt": "Quantize model now? 
(This may take a while)", + "run_quantizing": "Quantizing model...", + "run_starting_server": "Starting server...", + "run_server_mode": "Mode: SGLang + kt-kernel", + "run_server_port": "Port: {port}", + "run_gpu_experts": "GPU experts: {count}/layer", + "run_cpu_threads": "CPU threads: {count}", + "run_server_started": "Server started!", + "run_api_url": "API URL: http://{host}:{port}", + "run_docs_url": "Docs URL: http://{host}:{port}/docs", + "run_stop_hint": "Press Ctrl+C to stop the server", + "run_model_not_found": "Model '{name}' not found. Run 'kt download' first.", + "run_multiple_matches": "Multiple models found. Please select:", + "run_select_model": "Select model", + "run_select_model_title": "Select a model to run", + "run_select_model_prompt": "Enter number", + "run_local_models": "Local Models (Downloaded)", + "run_registered_models": "Registered Models", + # Download command + "download_list_title": "Available Models", + "download_searching": "Searching for model '{name}'...", + "download_found": "Found: {name}", + "download_multiple_found": "Multiple matches found:", + "download_select": "Select model to download:", + "download_destination": "Destination: {path}", + "download_starting": "Starting download...", + "download_progress": "Downloading {name}...", + "download_complete": "Download complete!", + "download_already_exists": "Model already exists at {path}", + "download_overwrite_prompt": "Overwrite existing files?", + # Quant command + "quant_input_path": "Input path: {path}", + "quant_output_path": "Output path: {path}", + "quant_method": "Quantization method: {method}", + "quant_starting": "Starting quantization...", + "quant_progress": "Quantizing...", + "quant_complete": "Quantization complete!", + "quant_input_not_found": "Input model not found at {path}", + # SFT command + "sft_mode_train": "Training mode", + "sft_mode_chat": "Chat mode", + "sft_mode_export": "Export mode", + "sft_config_path": "Config file: {path}", + "sft_starting": 
"Starting {mode}...", + "sft_complete": "{mode} complete!", + "sft_config_not_found": "Config file not found: {path}", + # Bench command + "bench_starting": "Starting benchmark...", + "bench_type": "Benchmark type: {type}", + "bench_complete": "Benchmark complete!", + "bench_results_title": "Benchmark Results", + # Common prompts + "prompt_continue": "Continue?", + "prompt_select": "Please select:", + "prompt_enter_value": "Enter value:", + "prompt_confirm_action": "Confirm this action?", + # First-run setup - Model path selection + "setup_model_path_title": "Model Storage Location", + "setup_model_path_desc": "LLM models are large (50-200GB+). Please select a storage location with sufficient space:", + "setup_scanning_disks": "Scanning available storage locations...", + "setup_disk_option": "{path} ({available} available / {total} total)", + "setup_disk_option_recommended": "{path} ({available} available / {total} total) [Recommended]", + "setup_custom_path": "Enter custom path", + "setup_enter_custom_path": "Enter the path for model storage", + "setup_path_not_exist": "Path does not exist. Create it?", + "setup_path_no_write": "No write permission for this path. Please choose another.", + "setup_path_low_space": "Warning: Less than 100GB available. Large models may not fit.", + "setup_model_path_set": "Model storage path set to: {path}", + "setup_no_large_disk": "No large storage locations found. Using default path.", + "setup_scanning_models": "Scanning for existing models...", + "setup_found_models": "Found {count} model(s):", + "setup_model_info": "{name} ({size}, {type})", + "setup_no_models_found": "No existing models found in this location.", + "setup_location_has_models": "{count} model(s) found", + "setup_installing_completion": "Installing shell completion for {shell}...", + "setup_completion_installed": "Shell completion installed! Restart terminal to enable.", + "setup_completion_failed": "Failed to install shell completion. 
Run 'kt --install-completion' manually.", + # Auto completion + "completion_installed_title": "Tab Completion", + "completion_installed_for": "Shell completion installed for {shell}", + "completion_activate_now": "To enable completion in this terminal session, run:", + "completion_next_session": "Completion will be automatically enabled in new terminal sessions.", + # SGLang + "sglang_not_found": "SGLang not found", + "sglang_pypi_warning": "SGLang from PyPI may not be compatible with kt-kernel", + "sglang_pypi_hint": 'SGLang from PyPI may not be compatible. Install from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_install_hint": 'Install SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_recommend_source": 'Recommend reinstalling from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_kt_kernel_not_supported": "SGLang does not support kt-kernel (missing --kt-gpu-prefill-token-threshold parameter)", + "sglang_checking_kt_kernel_support": "Checking SGLang kt-kernel support...", + "sglang_kt_kernel_supported": "SGLang kt-kernel support verified", + # Chat + "chat_proxy_detected": "Proxy detected in environment", + "chat_proxy_confirm": "Use proxy for connection?", + "chat_proxy_disabled": "Proxy disabled for this session", + # Model command + "model_supported_title": "KTransformers Supported Models", + "model_column_model": "Model", + "model_column_status": "Status", + "model_column_local_path": "Local Path", + "model_status_local": "Local", + "model_status_not_downloaded": "Not downloaded", + "model_usage_title": "Usage", + "model_usage_download": "Download a model:", + "model_usage_list_local": "List local models:", + "model_usage_search": "Search models:", + "model_storage_paths_title": "Model Storage Paths", + "model_local_models_title": "Locally Downloaded Models", + 
"model_available_models_title": "Available Models", + "model_no_local_models": "No locally downloaded models found", + "model_download_hint": "Download a model with:", + "model_download_usage_hint": "Usage: kt model download ", + "model_download_list_hint": "Use 'kt model download --list' to see available models.", + "model_download_hf_hint": "Or specify a HuggingFace repo directly: kt model download org/model-name", + "model_saved_to": "Model saved to: {path}", + "model_start_with": "Start with: kt run {name}", + "model_download_failed": "Download failed: {error}", + "model_hf_cli_not_found": "huggingface-cli not found. Install with: pip install huggingface-hub", + "model_path_not_exist": "Path does not exist: {path}", + "model_create_directory": "Create directory {path}?", + "model_created_directory": "Created directory: {path}", + "model_create_dir_failed": "Failed to create directory: {error}", + "model_path_added": "Added model path: {path}", + "model_path_removed": "Removed model path: {path}", + "model_path_not_found": "Path not found in configuration or cannot remove last path: {path}", + "model_search_no_results": "No models found matching '{query}'", + "model_search_results_title": "Search Results for '{query}'", + "model_column_name": "Name", + "model_column_hf_repo": "HuggingFace Repo", + "model_column_aliases": "Aliases", + # Coming soon + "feature_coming_soon": "This feature is coming soon...", + }, + "zh": { + # General + "welcome": "欢迎使用 KTransformers!", + "goodbye": "再见!", + "error": "错误", + "warning": "警告", + "success": "成功", + "info": "信息", + "yes": "是", + "no": "否", + "cancel": "取消", + "confirm": "确认", + "done": "完成", + "failed": "失败", + "skip": "跳过", + "back": "返回", + "next": "下一步", + "retry": "重试", + "abort": "中止", + # Version command + "version_info": "KTransformers CLI", + "version_python": "Python", + "version_platform": "平台", + "version_cuda": "CUDA", + "version_cuda_not_found": "未找到", + "version_kt_kernel": "kt-kernel", + 
"version_ktransformers": "ktransformers", + "version_sglang": "sglang", + "version_llamafactory": "llamafactory", + "version_not_installed": "未安装", + # Install command + "install_detecting_env": "检测环境管理工具...", + "install_found": "发现 {name} (版本 {version})", + "install_not_found": "未找到: {name}", + "install_checking_env": "检查现有环境...", + "install_env_exists": "发现现有 'kt' 环境", + "install_env_not_exists": "未发现 'kt' 环境", + "install_no_env_manager": "未检测到虚拟环境管理工具", + "install_select_method": "请选择安装方式:", + "install_method_conda": "创建新的 conda 环境 'kt' (推荐)", + "install_method_venv": "创建新的 venv 环境", + "install_method_uv": "创建新的 uv 环境 (快速)", + "install_method_docker": "使用 Docker 容器", + "install_method_system": "安装到系统 Python (不推荐)", + "install_select_mode": "请选择安装模式:", + "install_mode_inference": "推理模式 - 安装 kt-kernel + SGLang", + "install_mode_sft": "训练模式 - 安装 kt-sft + LlamaFactory", + "install_mode_full": "完整安装 - 安装所有组件", + "install_creating_env": "正在创建 {type} 环境 '{name}'...", + "install_env_created": "环境创建成功", + "install_installing_deps": "正在安装依赖...", + "install_checking_deps": "检查依赖版本...", + "install_dep_ok": "正常", + "install_dep_outdated": "需更新", + "install_dep_missing": "缺失", + "install_installing_pytorch": "正在安装 PyTorch...", + "install_installing_from_requirements": "从依赖文件安装...", + "install_deps_outdated": "发现 {count} 个包需要更新,是否继续?", + "install_updating": "正在更新包...", + "install_complete": "安装完成!", + "install_activate_hint": "激活环境: {command}", + "install_start_hint": "开始使用: kt run --help", + "install_docker_pulling": "正在拉取 Docker 镜像...", + "install_docker_complete": "Docker 镜像已就绪!", + "install_docker_run_hint": "运行: docker run --gpus all -p 30000:30000 {image} kt run {model}", + "install_in_venv": "当前在虚拟环境中: {name}", + "install_continue_without_venv": "继续安装到系统 Python?", + "install_already_installed": "所有依赖已安装!", + "install_confirm": "安装 {count} 个包?", + # Install - System dependencies + "install_checking_system_deps": "检查系统依赖...", + "install_dep_name": "依赖项", + 
"install_dep_status": "状态", + "install_deps_all_installed": "所有系统依赖已安装", + "install_deps_install_prompt": "是否安装缺失的依赖?", + "install_installing_system_deps": "正在安装系统依赖...", + "install_installing_dep": "正在安装 {name}", + "install_dep_no_install_cmd": "{os} 系统上没有 {name} 的安装命令", + "install_dep_install_failed": "安装 {name} 失败", + "install_deps_skipped": "跳过依赖安装", + "install_deps_failed": "系统依赖安装失败", + # Install - CPU detection + "install_auto_detect_cpu": "正在自动检测 CPU 能力...", + "install_cpu_features": "检测到的 CPU 特性: {features}", + "install_cpu_no_features": "未检测到高级 CPU 特性", + # Install - Build configuration + "install_build_config": "构建配置:", + "install_native_warning": "注意: 二进制文件仅针对当前 CPU 优化(不可移植)", + "install_building_from_source": "正在从源码构建 kt-kernel...", + "install_build_failed": "构建失败", + "install_build_success": "构建成功", + # Install - Verification + "install_verifying": "正在验证安装...", + "install_verify_success": "kt-kernel {version} ({variant} 变体) 安装成功", + "install_verify_failed": "验证失败: {error}", + # Install - Docker + "install_docker_guide_title": "Docker 安装", + "install_docker_guide_desc": "有关 Docker 安装,请参阅官方指南:", + # Config command + "config_show_title": "当前配置", + "config_set_success": "配置已更新: {key} = {value}", + "config_get_value": "{key} = {value}", + "config_get_not_found": "未找到配置项 '{key}'", + "config_reset_confirm": "这将重置所有配置为默认值。是否继续?", + "config_reset_success": "配置已重置为默认值", + "config_file_location": "配置文件: {path}", + # Doctor command + "doctor_title": "KTransformers 环境诊断", + "doctor_checking": "正在运行诊断...", + "doctor_check_python": "Python 版本", + "doctor_check_cuda": "CUDA 可用性", + "doctor_check_gpu": "GPU 检测", + "doctor_check_cpu": "CPU", + "doctor_check_cpu_isa": "CPU 指令集", + "doctor_check_numa": "NUMA 拓扑", + "doctor_check_memory": "系统内存", + "doctor_check_disk": "磁盘空间", + "doctor_check_packages": "必需的包", + "doctor_check_env": "环境变量", + "doctor_status_ok": "正常", + "doctor_status_warning": "警告", + "doctor_status_error": "错误", + "doctor_gpu_found": "发现 {count} 个 GPU: 
{names}", + "doctor_gpu_not_found": "未检测到 GPU", + "doctor_cpu_info": "{name} ({cores} 核心 / {threads} 线程)", + "doctor_cpu_isa_info": "{isa_list}", + "doctor_cpu_isa_missing": "缺少推荐指令集: {missing}", + "doctor_numa_info": "{nodes} 个节点", + "doctor_numa_detail": "{node}: CPU {cpus}", + "doctor_memory_info": "{available} 可用 / {total} 总计", + "doctor_memory_freq": "{available} 可用 / {total} 总计 ({freq}MHz {type})", + "doctor_disk_info": "{path} 有 {available} 可用空间", + "doctor_all_ok": "所有检查通过!您的环境已就绪。", + "doctor_has_issues": "发现一些问题,请查看上方的警告/错误信息。", + # Run command + "run_detecting_hardware": "检测硬件配置...", + "run_gpu_info": "GPU: {name} ({vram}GB 显存)", + "run_cpu_info": "CPU: {name} ({cores} 核心, {numa} NUMA 节点)", + "run_ram_info": "内存: {total}GB", + "run_checking_model": "检查模型状态...", + "run_model_path": "模型路径: {path}", + "run_weights_not_found": "未找到量化权重", + "run_quant_prompt": "是否现在量化模型?(这可能需要一些时间)", + "run_quantizing": "正在量化模型...", + "run_starting_server": "正在启动服务器...", + "run_server_mode": "模式: SGLang + kt-kernel", + "run_server_port": "端口: {port}", + "run_gpu_experts": "GPU 专家: {count}/层", + "run_cpu_threads": "CPU 线程: {count}", + "run_server_started": "服务器已启动!", + "run_api_url": "API 地址: http://{host}:{port}", + "run_docs_url": "文档地址: http://{host}:{port}/docs", + "run_stop_hint": "按 Ctrl+C 停止服务器", + "run_model_not_found": "未找到模型 '{name}'。请先运行 'kt download'。", + "run_multiple_matches": "找到多个匹配的模型,请选择:", + "run_select_model": "选择模型", + "run_select_model_title": "选择要运行的模型", + "run_select_model_prompt": "输入编号", + "run_local_models": "本地模型 (已下载)", + "run_registered_models": "注册模型", + # Download command + "download_list_title": "可用模型", + "download_searching": "正在搜索模型 '{name}'...", + "download_found": "找到: {name}", + "download_multiple_found": "找到多个匹配:", + "download_select": "选择要下载的模型:", + "download_destination": "目标路径: {path}", + "download_starting": "开始下载...", + "download_progress": "正在下载 {name}...", + "download_complete": "下载完成!", + "download_already_exists": "模型已存在于 
{path}", + "download_overwrite_prompt": "是否覆盖现有文件?", + # Quant command + "quant_input_path": "输入路径: {path}", + "quant_output_path": "输出路径: {path}", + "quant_method": "量化方法: {method}", + "quant_starting": "开始量化...", + "quant_progress": "正在量化...", + "quant_complete": "量化完成!", + "quant_input_not_found": "未找到输入模型: {path}", + # SFT command + "sft_mode_train": "训练模式", + "sft_mode_chat": "聊天模式", + "sft_mode_export": "导出模式", + "sft_config_path": "配置文件: {path}", + "sft_starting": "正在启动 {mode}...", + "sft_complete": "{mode} 完成!", + "sft_config_not_found": "未找到配置文件: {path}", + # Bench command + "bench_starting": "开始基准测试...", + "bench_type": "测试类型: {type}", + "bench_complete": "基准测试完成!", + "bench_results_title": "基准测试结果", + # Common prompts + "prompt_continue": "是否继续?", + "prompt_select": "请选择:", + "prompt_enter_value": "请输入:", + "prompt_confirm_action": "确认此操作?", + # First-run setup - Model path selection + "setup_model_path_title": "模型存储位置", + "setup_model_path_desc": "大语言模型体积较大(50-200GB+)。请选择一个有足够空间的存储位置:", + "setup_scanning_disks": "正在扫描可用存储位置...", + "setup_disk_option": "{path} (可用 {available} / 总共 {total})", + "setup_disk_option_recommended": "{path} (可用 {available} / 总共 {total}) [推荐]", + "setup_custom_path": "输入自定义路径", + "setup_enter_custom_path": "请输入模型存储路径", + "setup_path_not_exist": "路径不存在,是否创建?", + "setup_path_no_write": "没有该路径的写入权限,请选择其他路径。", + "setup_path_low_space": "警告:可用空间不足 100GB,可能无法存储大型模型。", + "setup_model_path_set": "模型存储路径已设置为: {path}", + "setup_no_large_disk": "未发现大容量存储位置,使用默认路径。", + "setup_scanning_models": "正在扫描已有模型...", + "setup_found_models": "发现 {count} 个模型:", + "setup_model_info": "{name} ({size}, {type})", + "setup_no_models_found": "该位置未发现已有模型。", + "setup_location_has_models": "发现 {count} 个模型", + "setup_installing_completion": "正在为 {shell} 安装命令补全...", + "setup_completion_installed": "命令补全已安装!重启终端后生效。", + "setup_completion_failed": "命令补全安装失败。请手动运行 'kt --install-completion'。", + # Auto completion + "completion_installed_title": "命令补全", + 
"completion_installed_for": "已为 {shell} 安装命令补全", + "completion_activate_now": "在当前终端会话中启用补全,请运行:", + "completion_next_session": "新的终端会话将自动启用补全。", + # SGLang + "sglang_not_found": "未找到 SGLang", + "sglang_pypi_warning": "PyPI 版本的 SGLang 可能与 kt-kernel 不兼容", + "sglang_pypi_hint": 'PyPI 版本可能不兼容。从源码安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_install_hint": '安装 SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_recommend_source": '建议从源码重新安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"', + "sglang_kt_kernel_not_supported": "SGLang 不支持 kt-kernel (缺少 --kt-gpu-prefill-token-threshold 参数)", + "sglang_checking_kt_kernel_support": "正在检查 SGLang kt-kernel 支持...", + "sglang_kt_kernel_supported": "SGLang kt-kernel 支持已验证", + # Chat + "chat_proxy_detected": "检测到环境中存在代理设置", + "chat_proxy_confirm": "是否使用代理连接?", + "chat_proxy_disabled": "已在本次会话中禁用代理", + # Model command + "model_supported_title": "KTransformers 支持的模型", + "model_column_model": "模型", + "model_column_status": "状态", + "model_column_local_path": "本地路径", + "model_status_local": "本地", + "model_status_not_downloaded": "未下载", + "model_usage_title": "使用方法", + "model_usage_download": "下载模型:", + "model_usage_list_local": "列出本地模型:", + "model_usage_search": "搜索模型:", + "model_storage_paths_title": "模型存储路径", + "model_local_models_title": "本地已下载的模型", + "model_available_models_title": "可用模型", + "model_no_local_models": "未找到本地已下载的模型", + "model_download_hint": "下载模型:", + "model_download_usage_hint": "用法: kt model download <模型名称>", + "model_download_list_hint": "使用 'kt model download --list' 查看可用模型。", + "model_download_hf_hint": "或直接指定 HuggingFace 仓库: kt model download org/model-name", + "model_saved_to": "模型已保存到: {path}", + "model_start_with": "启动命令: kt run {name}", + "model_download_failed": "下载失败: {error}", + "model_hf_cli_not_found": "未找到 huggingface-cli。请安装: pip 
install huggingface-hub", + "model_path_not_exist": "路径不存在: {path}", + "model_create_directory": "创建目录 {path}?", + "model_created_directory": "已创建目录: {path}", + "model_create_dir_failed": "创建目录失败: {error}", + "model_path_added": "已添加模型路径: {path}", + "model_path_removed": "已移除模型路径: {path}", + "model_path_not_found": "路径未找到或无法移除最后一个路径: {path}", + "model_search_no_results": "未找到匹配 '{query}' 的模型", + "model_search_results_title": "'{query}' 的搜索结果", + "model_column_name": "名称", + "model_column_hf_repo": "HuggingFace 仓库", + "model_column_aliases": "别名", + # Coming soon + "feature_coming_soon": "此功能即将推出...", + }, +} + + +# Cache for language detection to avoid repeated I/O +_lang_cache: str | None = None + + +def get_lang() -> str: + """ + Detect the current language setting. + + Priority: + 1. KT_LANG environment variable + 2. Config file general.language setting + 3. LANG environment variable (if config is "auto") + 4. Default to English + + Returns: + Language code: "zh" for Chinese, "en" for English + """ + global _lang_cache + + # 1. Check KT_LANG environment variable (highest priority) + kt_lang = os.environ.get("KT_LANG", "").lower() + if kt_lang: + return "zh" if kt_lang.startswith("zh") else "en" + + # 2. Return cached value if available (avoids I/O on every call) + if _lang_cache is not None: + return _lang_cache + + # 3. Check config file setting (with caching) + # Import here to avoid circular imports + from kt_kernel.cli.config.settings import get_settings + + try: + settings = get_settings() + config_lang = settings.get("general.language", "auto") + if config_lang and config_lang != "auto": + lang = "zh" if config_lang.lower().startswith("zh") else "en" + _lang_cache = lang + return lang + except Exception: + # If settings fail to load, continue with system detection + pass + + # 4. 
Check system LANG environment variable + system_lang = os.environ.get("LANG", "").lower() + lang = "zh" if system_lang.startswith("zh") else "en" + _lang_cache = lang + return lang + + +def t(msg_key: str, **kwargs: Any) -> str: + """ + Translate a message key to the current language. + + Args: + msg_key: Message key to translate + **kwargs: Format arguments for the message + + Returns: + Translated and formatted message string + + Example: + >>> t("welcome") + "Welcome to KTransformers!" # or "欢迎使用 KTransformers!" in Chinese + + >>> t("install_found", name="conda", version="24.1.0") + "Found conda (version 24.1.0)" + """ + lang = get_lang() + messages = MESSAGES.get(lang, MESSAGES["en"]) + message = messages.get(msg_key, MESSAGES["en"].get(msg_key, msg_key)) + + if kwargs: + try: + return message.format(**kwargs) + except KeyError: + return message + return message + + +def set_lang(lang: str) -> None: + """ + Set the language for the current session. + + Args: + lang: Language code ("en" or "zh") + """ + global _lang_cache + os.environ["KT_LANG"] = lang + _lang_cache = lang # Update cache when language is explicitly set diff --git a/kt-kernel/python/cli/main.py b/kt-kernel/python/cli/main.py new file mode 100644 index 0000000..7be1dce --- /dev/null +++ b/kt-kernel/python/cli/main.py @@ -0,0 +1,436 @@ +""" +Main entry point for kt-cli. + +KTransformers CLI - A unified command-line interface for KTransformers. +""" + +import sys + +import typer + +from kt_kernel.cli import __version__ +from kt_kernel.cli.commands import bench, chat, config, doctor, model, quant, run, sft, version +from kt_kernel.cli.i18n import t, set_lang, get_lang + + +def _get_app_help() -> str: + """Get app help text based on current language.""" + lang = get_lang() + if lang == "zh": + return "KTransformers CLI - KTransformers 统一命令行界面" + return "KTransformers CLI - A unified command-line interface for KTransformers." 
+ + +def _get_help(key: str) -> str: + """Get help text based on current language.""" + help_texts = { + "version": {"en": "Show version information", "zh": "显示版本信息"}, + "run": {"en": "Start model inference server", "zh": "启动模型推理服务器"}, + "chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"}, + "quant": {"en": "Quantize model weights", "zh": "量化模型权重"}, + "bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"}, + "microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"}, + "doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"}, + "model": {"en": "Manage models and storage paths", "zh": "管理模型和存储路径"}, + "config": {"en": "Manage configuration", "zh": "管理配置"}, + "sft": {"en": "Fine-tuning with LlamaFactory", "zh": "使用 LlamaFactory 进行微调"}, + } + lang = get_lang() + return help_texts.get(key, {}).get(lang, help_texts.get(key, {}).get("en", key)) + + +# Create main app with dynamic help +app = typer.Typer( + name="kt", + help="KTransformers CLI - A unified command-line interface for KTransformers.", + no_args_is_help=True, + add_completion=False, # Use static completion scripts instead of dynamic completion + rich_markup_mode="rich", +) + + +def _update_help_texts() -> None: + """Update all help texts based on current language setting.""" + # Update main app help + app.info.help = _get_app_help() + + # Update command help texts + for cmd_info in app.registered_commands: + # cmd_info is a CommandInfo object + if hasattr(cmd_info, "name") and cmd_info.name: + cmd_info.help = _get_help(cmd_info.name) + + # Update sub-app help texts + for group_info in app.registered_groups: + if hasattr(group_info, "name") and group_info.name: + group_info.help = _get_help(group_info.name) + + +# Register commands +app.command(name="version", help="Show version information")(version.version) +app.command(name="run", help="Start model inference server")(run.run) +app.command(name="chat", help="Interactive chat with running model")(chat.chat) 
+app.command(name="quant", help="Quantize model weights")(quant.quant) +app.command(name="bench", help="Run full benchmark")(bench.bench) +app.command(name="microbench", help="Run micro-benchmark")(bench.microbench) +app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor) + +# Register sub-apps +app.add_typer(model.app, name="model", help="Manage models and storage paths") +app.add_typer(config.app, name="config", help="Manage configuration") +app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory") + + +def check_first_run() -> None: + """Check if this is the first run and prompt for language setup.""" + import os + + # Skip if not running in interactive terminal + if not sys.stdin.isatty(): + return + + from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE + + # Only check if config file exists - don't create it yet + if not DEFAULT_CONFIG_FILE.exists(): + # First run - show welcome and language selection + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + _show_first_run_setup(settings) + else: + # Config exists - check if initialized + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + if not settings.get("general._initialized"): + _show_first_run_setup(settings) + + +def _show_first_run_setup(settings) -> None: + """Show first-run setup wizard.""" + from rich.console import Console + from rich.panel import Panel + from rich.prompt import Prompt, Confirm + from rich.spinner import Spinner + from rich.live import Live + + from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location + + console = Console() + + # Welcome message + console.print() + console.print( + Panel.fit( + "[bold cyan]Welcome to KTransformers CLI! 
/ 欢迎使用 KTransformers CLI![/bold cyan]\n\n" + "Let's set up your preferences.\n" + "让我们设置您的偏好。", + title="kt-cli", + border_style="cyan", + ) + ) + console.print() + + # Language selection + console.print("[bold]Select your preferred language / 选择您的首选语言:[/bold]") + console.print() + console.print(" [cyan][1][/cyan] English") + console.print(" [cyan][2][/cyan] 中文 (Chinese)") + console.print() + + while True: + choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1") + + if choice == "1": + lang = "en" + break + elif choice == "2": + lang = "zh" + break + + # Save language setting + settings.set("general.language", lang) + set_lang(lang) + + # Confirmation message + console.print() + if lang == "zh": + console.print("[green]✓[/green] 语言已设置为中文") + else: + console.print("[green]✓[/green] Language set to English") + + # Model storage path selection + console.print() + console.print(f"[bold]{t('setup_model_path_title')}[/bold]") + console.print() + console.print(f"[dim]{t('setup_model_path_desc')}[/dim]") + console.print() + + # Scan for storage locations + console.print(f"[dim]{t('setup_scanning_disks')}[/dim]") + locations = scan_storage_locations(min_size_gb=50.0) + console.print() + + if locations: + # Scan for models in each location + console.print(f"[dim]{t('setup_scanning_models')}[/dim]") + location_models: dict[str, list] = {} + for loc in locations[:5]: + models = scan_models_in_location(loc, max_depth=2) + if models: + location_models[loc.path] = models + console.print() + + # Show options + for i, loc in enumerate(locations[:5], 1): # Show top 5 options + available = format_size_gb(loc.available_gb) + total = format_size_gb(loc.total_gb) + + # Build the option string + if i == 1: + option_str = t("setup_disk_option_recommended", path=loc.path, available=available, total=total) + else: + option_str = t("setup_disk_option", path=loc.path, available=available, total=total) + + # Add model count if any + if loc.path in location_models: + 
model_count = len(location_models[loc.path]) + option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]" + + console.print(f" [cyan][{i}][/cyan] {option_str}") + + # Show first few models found in this location + if loc.path in location_models: + for model in location_models[loc.path][:3]: # Show up to 3 models + size_str = format_size_gb(model.size_gb) + console.print(f" [dim]• {model.name} ({size_str})[/dim]") + if len(location_models[loc.path]) > 3: + remaining = len(location_models[loc.path]) - 3 + console.print(f" [dim] ... +{remaining} more[/dim]") + + # Custom path option + custom_idx = min(len(locations), 5) + 1 + console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}") + console.print() + + valid_choices = [str(i) for i in range(1, custom_idx + 1)] + path_choice = Prompt.ask(t("prompt_select"), choices=valid_choices, default="1") + + if path_choice == str(custom_idx): + # Custom path + selected_path = _prompt_custom_path(console, settings) + else: + selected_path = locations[int(path_choice) - 1].path + else: + # No large storage found, ask for custom path + console.print(f"[yellow]{t('setup_no_large_disk')}[/yellow]") + console.print() + selected_path = _prompt_custom_path(console, settings) + + # Ensure the path exists + import os + from pathlib import Path + + if not os.path.exists(selected_path): + if Confirm.ask(t("setup_path_not_exist"), default=True): + try: + Path(selected_path).mkdir(parents=True, exist_ok=True) + except (OSError, PermissionError) as e: + console.print(f"[red]{t('error')}: {e}[/red]") + # Fall back to default + selected_path = str(Path.home() / ".ktransformers" / "models") + Path(selected_path).mkdir(parents=True, exist_ok=True) + + # Check available space and warn if low + from kt_kernel.cli.utils.environment import detect_disk_space_gb + + available_gb, _ = detect_disk_space_gb( + selected_path if os.path.exists(selected_path) else str(Path(selected_path).parent) + ) + if available_gb 
< 100: + console.print(f"[yellow]{t('setup_path_low_space')}[/yellow]") + + # Save the path + settings.set("paths.models", selected_path) + settings.set("general._initialized", True) + + console.print() + console.print(f"[green]✓[/green] {t('setup_model_path_set', path=selected_path)}") + console.print() + + # Tips + if lang == "zh": + console.print("[dim]提示: 运行 'kt config show' 查看所有配置[/dim]") + else: + console.print("[dim]Tip: Run 'kt config show' to view all settings[/dim]") + + console.print() + + +def _prompt_custom_path(console, settings) -> str: + """Prompt user to enter a custom path.""" + from rich.prompt import Prompt + from pathlib import Path + import os + + default_path = str(Path.home() / ".ktransformers" / "models") + + while True: + custom_path = Prompt.ask(t("setup_enter_custom_path"), default=default_path) + + # Expand user home + custom_path = os.path.expanduser(custom_path) + + # Check if path exists or parent is writable + if os.path.exists(custom_path): + if os.access(custom_path, os.W_OK): + return custom_path + else: + console.print(f"[red]{t('setup_path_no_write')}[/red]") + else: + # Check if we can create it (parent writable) + parent = str(Path(custom_path).parent) + while not os.path.exists(parent) and parent != "/": + parent = str(Path(parent).parent) + + if os.access(parent, os.W_OK): + return custom_path + else: + console.print(f"[red]{t('setup_path_no_write')}[/red]") + + +def _install_shell_completion() -> None: + """Install shell completion scripts to user directories. 
+ + Uses standard locations that are auto-loaded by shell completion systems: + - Bash: ~/.local/share/bash-completion/completions/kt (auto-loaded by bash-completion 2.0+) + - Zsh: ~/.zfunc/_kt (requires fpath setup, but commonly used) + - Fish: ~/.config/fish/completions/kt.fish (auto-loaded) + """ + import os + import shutil + from pathlib import Path + + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + + # Check if already installed + if settings.get("general._completion_installed", False): + return + + # Detect current shell + shell = os.environ.get("SHELL", "") + if "zsh" in shell: + shell_name = "zsh" + elif "fish" in shell: + shell_name = "fish" + else: + shell_name = "bash" + + try: + cli_dir = Path(__file__).parent + completions_dir = cli_dir / "completions" + home = Path.home() + + installed = False + + if shell_name == "bash": + # Use XDG standard location for bash-completion (auto-loaded) + src_file = completions_dir / "kt-completion.bash" + dest_dir = home / ".local" / "share" / "bash-completion" / "completions" + dest_file = dest_dir / "kt" + + if src_file.exists(): + dest_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_file, dest_file) + installed = True + + elif shell_name == "zsh": + src_file = completions_dir / "_kt" + dest_dir = home / ".zfunc" + dest_file = dest_dir / "_kt" + + if src_file.exists(): + dest_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_file, dest_file) + installed = True + + elif shell_name == "fish": + # Fish auto-loads from this directory + src_file = completions_dir / "kt.fish" + dest_dir = home / ".config" / "fish" / "completions" + dest_file = dest_dir / "kt.fish" + + if src_file.exists(): + dest_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(src_file, dest_file) + installed = True + + # Mark as installed + settings.set("general._completion_installed", True) + + # For bash/zsh, completion will work in new terminals automatically + # (bash-completion 2.0+ 
auto-loads from ~/.local/share/bash-completion/completions/) + + except (OSError, IOError): + # Silently ignore errors - completion is not critical + pass + + +def _apply_saved_language() -> None: + """Apply the saved language setting. + + Priority: + 1. KT_LANG environment variable (if already set, don't override) + 2. Config file setting + 3. System locale (auto) + """ + import os + + # Don't override if KT_LANG is already set by user + if os.environ.get("KT_LANG"): + return + + from kt_kernel.cli.config.settings import get_settings + + settings = get_settings() + lang = settings.get("general.language", "auto") + + if lang != "auto": + set_lang(lang) + + +def main(): + """Main entry point.""" + # Apply saved language setting first (before anything else for correct help display) + _apply_saved_language() + + # Update help texts based on language + _update_help_texts() + + # Check for first run (but not for certain commands) + # Skip first-run check for: --help, config commands, version + args = sys.argv[1:] if len(sys.argv) > 1 else [] + skip_commands = ["--help", "-h", "config", "version", "--version"] + + should_check_first_run = True + for arg in args: + if arg in skip_commands: + should_check_first_run = False + break + + # Auto-install shell completion on first run + if should_check_first_run: + _install_shell_completion() + + # Check first run before running commands + if should_check_first_run and args: + check_first_run() + + app() + + +if __name__ == "__main__": + main() diff --git a/kt-kernel/python/cli/requirements/inference.txt b/kt-kernel/python/cli/requirements/inference.txt new file mode 100644 index 0000000..b5e6f1e --- /dev/null +++ b/kt-kernel/python/cli/requirements/inference.txt @@ -0,0 +1,6 @@ +# Inference dependencies for KTransformers +# NOTE: sglang is installed separately from source (see install.py) + +transformers>=4.45.0 +safetensors>=0.4.0 +huggingface-hub>=0.20.0 diff --git a/kt-kernel/python/cli/requirements/sft.txt 
b/kt-kernel/python/cli/requirements/sft.txt new file mode 100644 index 0000000..6981daa --- /dev/null +++ b/kt-kernel/python/cli/requirements/sft.txt @@ -0,0 +1,7 @@ +# SFT (Supervised Fine-Tuning) dependencies for KTransformers + +llamafactory>=0.9.0 +peft>=0.12.0 +transformers>=4.45.0 +datasets>=2.14.0 +accelerate>=0.30.0 diff --git a/kt-kernel/python/cli/utils/__init__.py b/kt-kernel/python/cli/utils/__init__.py new file mode 100644 index 0000000..42d4a7f --- /dev/null +++ b/kt-kernel/python/cli/utils/__init__.py @@ -0,0 +1,3 @@ +""" +Utility modules for kt-cli. +""" diff --git a/kt-kernel/python/cli/utils/console.py b/kt-kernel/python/cli/utils/console.py new file mode 100644 index 0000000..01f6f9e --- /dev/null +++ b/kt-kernel/python/cli/utils/console.py @@ -0,0 +1,249 @@ +""" +Console utilities for kt-cli. + +Provides Rich-based console output helpers for consistent formatting. +""" + +from typing import Optional + +from rich.console import Console +from rich.panel import Panel +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) +from rich.prompt import Confirm, Prompt +from rich.table import Table +from rich.theme import Theme + +from kt_kernel.cli.i18n import t + +# Custom theme for kt-cli +KT_THEME = Theme( + { + "info": "cyan", + "warning": "yellow", + "error": "bold red", + "success": "bold green", + "highlight": "bold magenta", + "muted": "dim", + } +) + +# Global console instance +console = Console(theme=KT_THEME) + + +def print_info(message: str, **kwargs) -> None: + """Print an info message.""" + console.print(f"[info]ℹ[/info] {message}", **kwargs) + + +def print_success(message: str, **kwargs) -> None: + """Print a success message.""" + console.print(f"[success]✓[/success] {message}", **kwargs) + + +def print_warning(message: str, **kwargs) -> None: + """Print a warning message.""" + 
console.print(f"[warning]⚠[/warning] {message}", **kwargs) + + +def print_error(message: str, **kwargs) -> None: + """Print an error message.""" + console.print(f"[error]✗[/error] {message}", **kwargs) + + +def print_step(message: str, **kwargs) -> None: + """Print a step indicator.""" + console.print(f"[highlight]→[/highlight] {message}", **kwargs) + + +def print_header(title: str, subtitle: Optional[str] = None) -> None: + """Print a header panel.""" + content = f"[bold]{title}[/bold]" + if subtitle: + content += f"\n[muted]{subtitle}[/muted]" + console.print(Panel(content, expand=False)) + + +def print_version_table(versions: dict[str, Optional[str]]) -> None: + """Print a version information table.""" + table = Table(show_header=False, box=None, padding=(0, 2)) + table.add_column("Component", style="bold") + table.add_column("Version") + + for name, version in versions.items(): + if version: + table.add_row(name, f"[success]{version}[/success]") + else: + table.add_row(name, f"[muted]{t('version_not_installed')}[/muted]") + + console.print(table) + + +def print_dependency_table(deps: list[dict]) -> None: + """Print a dependency status table.""" + table = Table(title=t("install_checking_deps")) + table.add_column(t("version_info"), style="bold") + table.add_column("Current") + table.add_column("Required") + table.add_column("Status") + + for dep in deps: + status = dep.get("status", "ok") + if status == "ok": + status_str = f"[success]{t('install_dep_ok')}[/success]" + elif status == "outdated": + status_str = f"[warning]{t('install_dep_outdated')}[/warning]" + else: + status_str = f"[error]{t('install_dep_missing')}[/error]" + + table.add_row( + dep["name"], + dep.get("installed", "-"), + dep.get("required", "-"), + status_str, + ) + + console.print(table) + + +def confirm(message: str, default: bool = True) -> bool: + """Ask for confirmation.""" + return Confirm.ask(message, default=default, console=console) + + +def prompt_choice(message: str, choices: 
list[str], default: Optional[str] = None) -> str: + """Prompt for a choice from a list.""" + # Display numbered choices + console.print(f"\n[bold]{message}[/bold]") + for i, choice in enumerate(choices, 1): + console.print(f" [highlight][{i}][/highlight] {choice}") + + while True: + response = Prompt.ask( + "\n" + t("prompt_select"), + console=console, + default=str(choices.index(default) + 1) if default else None, + ) + try: + idx = int(response) - 1 + if 0 <= idx < len(choices): + return choices[idx] + except ValueError: + # Check if response matches a choice directly + if response in choices: + return response + + print_error(f"Please enter a number between 1 and {len(choices)}") + + +def prompt_text(message: str, default: Optional[str] = None) -> str: + """Prompt for text input.""" + return Prompt.ask(message, console=console, default=default) + + +def create_progress() -> Progress: + """Create a progress bar for general tasks.""" + return Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + ) + + +def create_download_progress() -> Progress: + """Create a progress bar for downloads.""" + return Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + DownloadColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + console=console, + ) + + +def print_model_table(models: list[dict]) -> None: + """Print a table of models.""" + table = Table(title=t("download_list_title")) + table.add_column("Name", style="bold") + table.add_column("Repository") + table.add_column("Type") + table.add_column("Requirements") + + for model in models: + reqs = [] + if model.get("gpu_vram_gb"): + reqs.append(f"GPU: {model['gpu_vram_gb']}GB") + if model.get("cpu_ram_gb"): + reqs.append(f"RAM: {model['cpu_ram_gb']}GB") + + table.add_row( + model.get("name", ""), + model.get("hf_repo", ""), + model.get("type", ""), + ", 
".join(reqs) if reqs else "-", + ) + + console.print(table) + + +def print_hardware_info(gpu_info: str, cpu_info: str, ram_info: str) -> None: + """Print hardware information.""" + table = Table(show_header=False, box=None) + table.add_column("Icon", width=3) + table.add_column("Info") + + table.add_row("🖥️", gpu_info) + table.add_row("💻", cpu_info) + table.add_row("🧠", ram_info) + + console.print(Panel(table, title="Hardware", expand=False)) + + +def print_server_info( + mode: str, host: str, port: int, gpu_experts: int, cpu_threads: int +) -> None: + """Print server startup information.""" + table = Table(show_header=False, box=None) + table.add_column("Key", style="bold") + table.add_column("Value") + + table.add_row(t("run_server_mode").split(":")[0], mode) + table.add_row("Host", host) + table.add_row("Port", str(port)) + table.add_row(t("run_gpu_experts").split(":")[0], f"{gpu_experts}/layer") + table.add_row(t("run_cpu_threads").split(":")[0], str(cpu_threads)) + + console.print(Panel(table, title=t("run_server_started"), expand=False, border_style="green")) + + +def print_api_info(host: str, port: int) -> None: + """Print API endpoint information.""" + api_url = f"http://{host}:{port}" + docs_url = f"http://{host}:{port}/docs" + + console.print() + console.print(f" {t('run_api_url', host=host, port=port)}") + console.print(f" {t('run_docs_url', host=host, port=port)}") + console.print() + console.print(f" [muted]Test command:[/muted]") + console.print( + f" [dim]curl {api_url}/v1/chat/completions -H 'Content-Type: application/json' " + f"-d '{{\"model\": \"default\", \"messages\": [{{\"role\": \"user\", \"content\": \"Hello\"}}]}}'[/dim]" + ) + console.print() + console.print(f" [muted]{t('run_stop_hint')}[/muted]") diff --git a/kt-kernel/python/cli/utils/environment.py b/kt-kernel/python/cli/utils/environment.py new file mode 100644 index 0000000..422c6dc --- /dev/null +++ b/kt-kernel/python/cli/utils/environment.py @@ -0,0 +1,1108 @@ +""" +Environment 
detection utilities for kt-cli. + +Provides functions to detect: +- Virtual environment managers (conda, venv, uv, mamba) +- Python version and packages +- CUDA and GPU information +- System resources (CPU, RAM, disk) +""" + +import os +import platform +import shutil +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + + +@dataclass +class EnvManager: + """Information about an environment manager.""" + + name: str + version: str + path: str + + +@dataclass +class GPUInfo: + """Information about a GPU.""" + + index: int + name: str + vram_gb: float + cuda_capability: Optional[str] = None + + +@dataclass +class CPUInfo: + """Information about the CPU.""" + + name: str + cores: int + threads: int + numa_nodes: int + instruction_sets: list[str] = field(default_factory=list) # AVX, AVX2, AVX512, AMX, etc. + numa_info: dict = field(default_factory=dict) # node -> cpus mapping + + +@dataclass +class MemoryInfo: + """Information about system memory.""" + + total_gb: float + available_gb: float + frequency_mhz: Optional[int] = None + channels: Optional[int] = None + type: Optional[str] = None # DDR4, DDR5, etc. 
+ + +@dataclass +class SystemInfo: + """Complete system information.""" + + python_version: str + platform: str + cuda_version: Optional[str] + gpus: list[GPUInfo] + cpu: CPUInfo + ram_gb: float + env_managers: list[EnvManager] + + +def run_command(cmd: list[str], timeout: int = 10) -> Optional[str]: + """Run a command and return its output, or None if it fails.""" + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False) + if result.returncode == 0: + return result.stdout.strip() + return None + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return None + + +def detect_env_managers() -> list[EnvManager]: + """Detect available virtual environment managers.""" + managers = [] + + # Check conda + conda_path = shutil.which("conda") + if conda_path: + version = run_command(["conda", "--version"]) + if version: + # "conda 24.1.0" -> "24.1.0" + version = version.split()[-1] if version else "unknown" + managers.append(EnvManager(name="conda", version=version, path=conda_path)) + + # Check mamba + mamba_path = shutil.which("mamba") + if mamba_path: + version = run_command(["mamba", "--version"]) + if version: + # First line: "mamba 1.5.0" + version = version.split("\n")[0].split()[-1] if version else "unknown" + managers.append(EnvManager(name="mamba", version=version, path=mamba_path)) + + # Check uv + uv_path = shutil.which("uv") + if uv_path: + version = run_command(["uv", "--version"]) + if version: + # "uv 0.5.0" -> "0.5.0" + version = version.split()[-1] if version else "unknown" + managers.append(EnvManager(name="uv", version=version, path=uv_path)) + + # Check if venv is available (built into Python) + try: + import venv # noqa: F401 + + managers.append(EnvManager(name="venv", version="builtin", path="python -m venv")) + except ImportError: + pass + + return managers + + +def check_docker() -> Optional[EnvManager]: + """Check if Docker is available.""" + docker_path = shutil.which("docker") + if 
docker_path: + version = run_command(["docker", "--version"]) + if version: + # "Docker version 24.0.7, build afdd53b" + parts = version.split() + version = parts[2].rstrip(",") if len(parts) > 2 else "unknown" + return EnvManager(name="docker", version=version, path=docker_path) + return None + + +def check_kt_env_exists(manager: str, env_name: str = "kt") -> bool: + """Check if a kt environment exists for the given manager.""" + if manager == "conda" or manager == "mamba": + result = run_command([manager, "env", "list"]) + if result: + # Check if env_name appears as a separate word in the output + for line in result.split("\n"): + parts = line.split() + if parts and parts[0] == env_name: + return True + elif manager == "uv": + # uv uses .venv in the project directory or ~/.local/share/uv/envs/ + venv_path = Path.home() / ".local" / "share" / "uv" / "envs" / env_name + if venv_path.exists(): + return True + # Also check current directory + if Path(env_name).exists() and (Path(env_name) / "bin" / "python").exists(): + return True + elif manager == "venv": + # Check common locations + venv_path = Path.home() / ".virtualenvs" / env_name + if venv_path.exists(): + return True + if Path(env_name).exists() and (Path(env_name) / "bin" / "python").exists(): + return True + + return False + + +def get_kt_env_path(manager: str, env_name: str = "kt") -> Optional[Path]: + """Get the path to the kt environment.""" + if manager == "conda" or manager == "mamba": + result = run_command([manager, "env", "list"]) + if result: + for line in result.split("\n"): + parts = line.split() + if parts and parts[0] == env_name: + # The path is the last part + return Path(parts[-1]) + elif manager == "uv": + venv_path = Path.home() / ".local" / "share" / "uv" / "envs" / env_name + if venv_path.exists(): + return venv_path + elif manager == "venv": + venv_path = Path.home() / ".virtualenvs" / env_name + if venv_path.exists(): + return venv_path + + return None + + +def detect_cuda_version() -> 
Optional[str]: + """Detect CUDA version from nvidia-smi or nvcc.""" + # Try nvidia-smi first + nvidia_smi = run_command(["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"]) + if nvidia_smi: + # Get CUDA version from nvidia-smi + full_output = run_command(["nvidia-smi"]) + if full_output: + for line in full_output.split("\n"): + if "CUDA Version:" in line: + # "| CUDA Version: 12.1 |" + parts = line.split("CUDA Version:") + if len(parts) > 1: + version = parts[1].strip().split()[0] + return version + + # Try nvcc + nvcc_output = run_command(["nvcc", "--version"]) + if nvcc_output: + for line in nvcc_output.split("\n"): + if "release" in line.lower(): + # "Cuda compilation tools, release 12.1, V12.1.105" + parts = line.split("release") + if len(parts) > 1: + version = parts[1].strip().split(",")[0].strip() + return version + + return None + + +def detect_gpus() -> list[GPUInfo]: + """Detect available NVIDIA GPUs, respecting CUDA_VISIBLE_DEVICES.""" + gpus = [] + + nvidia_smi = run_command(["nvidia-smi", "--query-gpu=index,name,memory.total", "--format=csv,noheader,nounits"]) + + if nvidia_smi: + for line in nvidia_smi.strip().split("\n"): + parts = [p.strip() for p in line.split(",")] + if len(parts) >= 3: + try: + index = int(parts[0]) + name = parts[1] + vram_mb = float(parts[2]) + vram_gb = round(vram_mb / 1024, 1) + gpus.append(GPUInfo(index=index, name=name, vram_gb=vram_gb)) + except (ValueError, IndexError): + continue + + # Filter by CUDA_VISIBLE_DEVICES if set + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") + if cuda_visible is not None: + if cuda_visible == "": + # Empty string means no GPUs visible + return [] + + try: + # Parse CUDA_VISIBLE_DEVICES (can be "0,1,2" or "0-3" etc.) + visible_indices = _parse_cuda_visible_devices(cuda_visible) + # Filter GPUs to only those in CUDA_VISIBLE_DEVICES + filtered_gpus = [gpu for gpu in gpus if gpu.index in visible_indices] + # Re-index GPUs to match CUDA's logical indexing (0, 1, 2, ...) 
+ for i, gpu in enumerate(filtered_gpus): + # Keep original index in a comment, but CUDA sees them as 0,1,2... + gpu.index = i + return filtered_gpus + except ValueError: + # If parsing fails, return all GPUs as fallback + pass + + return gpus + + +def _parse_cuda_visible_devices(cuda_visible: str) -> list[int]: + """Parse CUDA_VISIBLE_DEVICES string into list of GPU indices. + + Supports formats like: + - "0,1,2,3" -> [0, 1, 2, 3] + - "0-3" -> [0, 1, 2, 3] + - "0,2-4,7" -> [0, 2, 3, 4, 7] + """ + indices = [] + parts = cuda_visible.split(",") + + for part in parts: + part = part.strip() + if "-" in part: + # Range like "0-3" + start, end = part.split("-") + indices.extend(range(int(start), int(end) + 1)) + else: + # Single index + indices.append(int(part)) + + return sorted(set(indices)) # Remove duplicates and sort + + +def detect_cpu_info() -> CPUInfo: + """Detect CPU information including instruction sets and NUMA topology.""" + name = "Unknown" + cores = os.cpu_count() or 1 + threads = cores + numa_nodes = 1 + instruction_sets: list[str] = [] + numa_info: dict[str, list[int]] = {} + + if platform.system() == "Linux": + try: + with open("/proc/cpuinfo", "r") as f: + content = f.read() + + # Get CPU name + for line in content.split("\n"): + if line.startswith("model name"): + name = line.split(":")[1].strip() + break + + # Get physical cores vs threads + cpu_cores = content.count("processor\t:") + if cpu_cores > 0: + threads = cpu_cores + + siblings = None + cores_per = None + for line in content.split("\n"): + if "siblings" in line: + siblings = int(line.split(":")[1].strip()) + if "cpu cores" in line: + cores_per = int(line.split(":")[1].strip()) + if siblings and cores_per: + cores = threads // (siblings // cores_per) if siblings > cores_per else threads + + # Get instruction sets from flags + for line in content.split("\n"): + if line.startswith("flags"): + flags = line.split(":")[1].strip().split() + instruction_sets = _parse_cpu_flags(flags) + break + + 
except (OSError, IOError, ValueError): + pass + + # Get NUMA topology + numa_path = Path("/sys/devices/system/node") + if numa_path.exists(): + numa_dirs = [d for d in numa_path.iterdir() if d.name.startswith("node")] + numa_nodes = len(numa_dirs) + + for node_dir in numa_dirs: + node_name = node_dir.name # e.g., "node0" + cpulist_path = node_dir / "cpulist" + if cpulist_path.exists(): + try: + cpulist = cpulist_path.read_text().strip() + numa_info[node_name] = _parse_cpu_list(cpulist) + except (OSError, IOError): + pass + + elif platform.system() == "Darwin": + # macOS + name_output = run_command(["sysctl", "-n", "machdep.cpu.brand_string"]) + if name_output: + name = name_output.strip() + cores_output = run_command(["sysctl", "-n", "hw.physicalcpu"]) + if cores_output: + cores = int(cores_output.strip()) + threads_output = run_command(["sysctl", "-n", "hw.logicalcpu"]) + if threads_output: + threads = int(threads_output.strip()) + + # Get instruction sets on macOS + features_output = run_command(["sysctl", "-n", "machdep.cpu.features"]) + if features_output: + flags = features_output.lower().split() + instruction_sets = _parse_cpu_flags(flags) + + return CPUInfo( + name=name, + cores=cores, + threads=threads, + numa_nodes=numa_nodes, + instruction_sets=instruction_sets, + numa_info=numa_info, + ) + + +def _parse_cpu_flags(flags: list[str]) -> list[str]: + """Parse CPU flags to extract relevant instruction sets for KTransformers.""" + # Instruction sets important for KTransformers/kt-kernel + relevant_instructions = { + # Basic SIMD + "sse": "SSE", + "sse2": "SSE2", + "sse3": "SSE3", + "ssse3": "SSSE3", + "sse4_1": "SSE4.1", + "sse4_2": "SSE4.2", + # AVX family + "avx": "AVX", + "avx2": "AVX2", + "avx512f": "AVX512F", + "avx512bw": "AVX512BW", + "avx512vl": "AVX512VL", + "avx512dq": "AVX512DQ", + "avx512cd": "AVX512CD", + "avx512vnni": "AVX512VNNI", + "avx512_bf16": "AVX512BF16", + "avx512_fp16": "AVX512FP16", + "avx_vnni": "AVX-VNNI", + # AMX (Advanced Matrix 
Extensions) - Intel + "amx_tile": "AMX-TILE", + "amx_bf16": "AMX-BF16", + "amx_int8": "AMX-INT8", + "amx_fp16": "AMX-FP16", + # Other relevant + "fma": "FMA", + "f16c": "F16C", + "bmi1": "BMI1", + "bmi2": "BMI2", + } + + found = [] + flags_lower = {f.lower() for f in flags} + + for flag, display_name in relevant_instructions.items(): + if flag in flags_lower: + found.append(display_name) + + # Sort by importance for display + priority = [ + "AMX-INT8", + "AMX-BF16", + "AMX-FP16", + "AMX-TILE", + "AVX512BF16", + "AVX512VNNI", + "AVX512F", + "AVX512BW", + "AVX512VL", + "AVX2", + "AVX", + "FMA", + "SSE4.2", + ] + result = [] + for p in priority: + if p in found: + result.append(p) + found.remove(p) + result.extend(sorted(found)) # Add remaining + + return result + + +def _parse_cpu_list(cpulist: str) -> list[int]: + """Parse CPU list string like '0-3,8-11' to list of CPU IDs.""" + cpus = [] + for part in cpulist.split(","): + if "-" in part: + start, end = part.split("-") + cpus.extend(range(int(start), int(end) + 1)) + else: + cpus.append(int(part)) + return cpus + + +def detect_memory_info() -> MemoryInfo: + """Detect detailed memory information including frequency and type.""" + total_gb = detect_ram_gb() + available_gb = detect_available_ram_gb() + frequency_mhz: Optional[int] = None + channels: Optional[int] = None + mem_type: Optional[str] = None + + if platform.system() == "Linux": + # Try dmidecode without sudo first (may work if user has permissions) + dmidecode_output = run_command(["dmidecode", "-t", "memory"]) + if dmidecode_output: + frequency_mhz, mem_type, channels = _parse_dmidecode_memory(dmidecode_output) + + # Fallback: try to read from /sys or /proc + if frequency_mhz is None: + frequency_mhz = _detect_memory_frequency_sysfs() + + elif platform.system() == "Darwin": + # macOS - use system_profiler + mem_output = run_command(["system_profiler", "SPMemoryDataType"]) + if mem_output: + frequency_mhz, mem_type = _parse_macos_memory(mem_output) + + 
return MemoryInfo( + total_gb=total_gb, + available_gb=available_gb, + frequency_mhz=frequency_mhz, + channels=channels, + type=mem_type, + ) + + +def _parse_dmidecode_memory(output: str) -> tuple[Optional[int], Optional[str], Optional[int]]: + """Parse dmidecode memory output.""" + frequency_mhz: Optional[int] = None + mem_type: Optional[str] = None + dimm_count = 0 + + for line in output.split("\n"): + line = line.strip() + if line.startswith("Speed:") and "MHz" in line: + try: + # "Speed: 4800 MHz" or "Speed: 4800 MT/s" + parts = line.split(":")[1].strip().split() + freq = int(parts[0]) + if freq > 0 and (frequency_mhz is None or freq > frequency_mhz): + frequency_mhz = freq + except (ValueError, IndexError): + pass + elif line.startswith("Type:") and mem_type is None: + type_val = line.split(":")[1].strip() + if type_val and type_val != "Unknown": + mem_type = type_val + elif line.startswith("Size:") and "MB" in line or "GB" in line: + dimm_count += 1 + + return frequency_mhz, mem_type, dimm_count if dimm_count > 0 else None + + +def _detect_memory_frequency_sysfs() -> Optional[int]: + """Try to detect memory frequency from sysfs.""" + # This is a fallback and may not work on all systems + try: + # Try reading from edac + edac_path = Path("/sys/devices/system/edac/mc") + if edac_path.exists(): + for mc_dir in edac_path.iterdir(): + freq_file = mc_dir / "mc_config" + if freq_file.exists(): + content = freq_file.read_text() + # Parse for frequency information + # Format varies by system + pass + except (OSError, IOError): + pass + + return None + + +def _parse_macos_memory(output: str) -> tuple[Optional[int], Optional[str]]: + """Parse macOS system_profiler memory output.""" + frequency_mhz: Optional[int] = None + mem_type: Optional[str] = None + + for line in output.split("\n"): + line = line.strip() + if "Speed:" in line: + try: + parts = line.split(":")[1].strip().split() + frequency_mhz = int(parts[0]) + except (ValueError, IndexError): + pass + elif "Type:" 
in line: + mem_type = line.split(":")[1].strip() + + return frequency_mhz, mem_type + + +def detect_ram_gb() -> float: + """Detect total system RAM in GB.""" + if platform.system() == "Linux": + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if line.startswith("MemTotal:"): + # "MemTotal: 32780516 kB" + kb = int(line.split()[1]) + return round(kb / 1024 / 1024, 1) + except (OSError, IOError, ValueError): + pass + elif platform.system() == "Darwin": + mem_output = run_command(["sysctl", "-n", "hw.memsize"]) + if mem_output: + return round(int(mem_output) / 1024 / 1024 / 1024, 1) + + # Fallback + try: + import psutil + + return round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 1) + except ImportError: + return 0.0 + + +def detect_available_ram_gb() -> float: + """Detect available system RAM in GB.""" + if platform.system() == "Linux": + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if line.startswith("MemAvailable:"): + kb = int(line.split()[1]) + return round(kb / 1024 / 1024, 1) + except (OSError, IOError, ValueError): + pass + + # Fallback + try: + import psutil + + return round(psutil.virtual_memory().available / 1024 / 1024 / 1024, 1) + except ImportError: + return 0.0 + + +def detect_disk_space_gb(path: str = "/") -> tuple[float, float]: + """Detect disk space (available, total) in GB for the given path.""" + try: + import shutil + + total, used, free = shutil.disk_usage(path) + return round(free / 1024 / 1024 / 1024, 1), round(total / 1024 / 1024 / 1024, 1) + except (OSError, IOError): + return 0.0, 0.0 + + +def get_installed_package_version(package_name: str) -> Optional[str]: + """Get the version of an installed Python package.""" + try: + from importlib.metadata import version + + return version(package_name) + except Exception: + return None + + +def get_system_info() -> SystemInfo: + """Gather complete system information.""" + return SystemInfo( + python_version=platform.python_version(), + 
platform=f"{platform.system()} {platform.release()}", + cuda_version=detect_cuda_version(), + gpus=detect_gpus(), + cpu=detect_cpu_info(), + ram_gb=detect_ram_gb(), + env_managers=detect_env_managers(), + ) + + +def is_in_virtual_env() -> bool: + """Check if currently running inside a virtual environment.""" + return ( + hasattr(sys, "real_prefix") + or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix) + or os.environ.get("VIRTUAL_ENV") is not None + or os.environ.get("CONDA_PREFIX") is not None + ) + + +def get_current_env_name() -> Optional[str]: + """Get the name of the current virtual environment.""" + if os.environ.get("CONDA_DEFAULT_ENV"): + return os.environ["CONDA_DEFAULT_ENV"] + if os.environ.get("VIRTUAL_ENV"): + return Path(os.environ["VIRTUAL_ENV"]).name + return None + + +# Import sys for is_in_virtual_env +import sys # noqa: E402 + + +@dataclass +class StorageLocation: + """Information about a storage location.""" + + path: str + available_gb: float + total_gb: float + is_writable: bool + mount_point: str + + +def scan_storage_locations(min_size_gb: float = 50.0) -> list[StorageLocation]: + """ + Scan system for potential model storage locations. 
+ + Looks for: + - Large mounted filesystems (> min_size_gb) + - Common model storage paths + - User home directory + + Args: + min_size_gb: Minimum available space in GB to consider + + Returns: + List of StorageLocation sorted by available space (descending) + """ + locations: dict[str, StorageLocation] = {} # Use dict to deduplicate by path + + # Get all mount points from /proc/mounts (Linux) + mount_points = _get_mount_points() + + for mount_point in mount_points: + try: + available_gb, total_gb = detect_disk_space_gb(mount_point) + + # Skip small or pseudo filesystems + if total_gb < 10: + continue + + # Check if writable + is_writable = os.access(mount_point, os.W_OK) + + # Create potential model paths under this mount + potential_paths = _get_potential_model_paths(mount_point) + + for path in potential_paths: + if path in locations: + continue + + # Get actual available space for this path + path_available, path_total = detect_disk_space_gb(path) + + if path_available >= min_size_gb: + path_writable = os.access(path, os.W_OK) if os.path.exists(path) else is_writable + locations[path] = StorageLocation( + path=path, + available_gb=path_available, + total_gb=path_total, + is_writable=path_writable, + mount_point=mount_point, + ) + except (OSError, IOError): + continue + + # Also check common model storage locations + common_paths = [ + str(Path.home() / ".ktransformers" / "models"), + str(Path.home() / "models"), + str(Path.home() / ".cache" / "huggingface"), + "/data/models", + "/models", + "/opt/models", + ] + + for path in common_paths: + if path in locations: + continue + try: + # Check if parent exists for paths that don't exist yet + check_path = path + while not os.path.exists(check_path) and check_path != "/": + check_path = str(Path(check_path).parent) + + if os.path.exists(check_path): + available_gb, total_gb = detect_disk_space_gb(check_path) + if available_gb >= min_size_gb: + is_writable = os.access(check_path, os.W_OK) + locations[path] = 
StorageLocation( + path=path, + available_gb=available_gb, + total_gb=total_gb, + is_writable=is_writable, + mount_point=check_path, + ) + except (OSError, IOError): + continue + + # Sort by available space descending, then by path + sorted_locations = sorted(locations.values(), key=lambda x: (-x.available_gb, x.path)) + + # Filter to only writable locations + return [loc for loc in sorted_locations if loc.is_writable] + + +def _get_mount_points() -> list[str]: + """Get all mount points on the system.""" + mount_points = [] + + if platform.system() == "Linux": + try: + with open("/proc/mounts", "r") as f: + for line in f: + parts = line.split() + if len(parts) >= 2: + mount_point = parts[1] + fs_type = parts[2] if len(parts) > 2 else "" + + # Skip pseudo filesystems + skip_fs = { + "proc", + "sysfs", + "devpts", + "tmpfs", + "cgroup", + "cgroup2", + "pstore", + "securityfs", + "debugfs", + "hugetlbfs", + "mqueue", + "fusectl", + "configfs", + "devtmpfs", + "efivarfs", + "autofs", + "binfmt_misc", + "overlay", + "nsfs", + "tracefs", + } + if fs_type in skip_fs: + continue + + # Skip paths that are clearly system paths + skip_prefixes = ("/sys", "/proc", "/dev", "/run/user") + if any(mount_point.startswith(p) for p in skip_prefixes): + continue + + mount_points.append(mount_point) + except (OSError, IOError): + pass + + # Always include home and root + mount_points.extend([str(Path.home()), "/"]) + + # Deduplicate while preserving order + seen = set() + unique_mounts = [] + for mp in mount_points: + if mp not in seen: + seen.add(mp) + unique_mounts.append(mp) + + return unique_mounts + + +def _get_potential_model_paths(mount_point: str) -> list[str]: + """Get potential model storage paths under a mount point.""" + paths = [] + + # The mount point itself (for dedicated data drives) + if mount_point not in ("/", "/home"): + paths.append(mount_point) + paths.append(os.path.join(mount_point, "models")) + + # If it's under home, suggest standard locations + home = 
def format_size_gb(size_gb: float) -> str:
    """Render a size in GB as a short human-readable string (GB or TB)."""
    # 1000 GB and above are shown in decimal terabytes.
    return f"{size_gb / 1000:.1f}TB" if size_gb >= 1000 else f"{size_gb:.1f}GB"


@dataclass
class LocalModel:
    """Information about a locally detected model."""

    name: str
    path: str
    size_gb: float
    model_type: str  # one of: "huggingface", "gguf", "safetensors"
    has_config: bool
    file_count: int
def _scan_directory_for_models(
    directory: str, models: dict[str, LocalModel], current_depth: int, max_depth: int
) -> None:
    """Recursively scan *directory*, recording any detected models in *models*.

    Recursion stops at *max_depth* and does not descend into a directory
    that was itself recognized as a model. Unreadable directories are
    skipped silently.
    """
    if current_depth > max_depth:
        return

    try:
        children = list(os.scandir(directory))
    except (PermissionError, OSError):
        return

    found = _detect_model_in_directory(directory, children)
    if found is not None:
        # A model directory terminates the walk on this branch.
        models[found.path] = found
        return

    # Recurse into visible (non-hidden) subdirectories only.
    for child in children:
        if child.is_dir() and not child.name.startswith("."):
            _scan_directory_for_models(child.path, models, current_depth + 1, max_depth)
incomplete + # Check for other model-related files + model_files = { + "model.safetensors.index.json", + "pytorch_model.bin.index.json", + "model.safetensors", + "pytorch_model.bin", + } + if entry_names & model_files: + model_type = "huggingface" + + if not model_type: + return None + + # Calculate directory size + size_bytes = _get_directory_size(directory) + size_gb = size_bytes / (1024**3) + + # Skip very small directories (likely incomplete or config-only) + if size_gb < 0.1: + return None + + # Get model name from directory name + name = os.path.basename(directory) + + # Count model files + file_count = len(safetensor_files) + len(gguf_files) + if not file_count: + # Count .bin files as fallback + file_count = len([e for e in entries if e.name.endswith(".bin") and e.is_file()]) + + return LocalModel( + name=name, + path=directory, + size_gb=round(size_gb, 1), + model_type=model_type, + has_config=has_config, + file_count=file_count, + ) + + +def _get_directory_size(directory: str) -> int: + """Get total size of a directory in bytes.""" + total_size = 0 + try: + for entry in os.scandir(directory): + try: + if entry.is_file(follow_symlinks=False): + total_size += entry.stat().st_size + elif entry.is_dir(follow_symlinks=False): + total_size += _get_directory_size(entry.path) + except (PermissionError, OSError): + continue + except (PermissionError, OSError): + pass + return total_size + + +def scan_models_in_location(location: StorageLocation, max_depth: int = 2) -> list[LocalModel]: + """Scan a storage location for models.""" + search_paths = [location.path] + + # Also check common subdirectories + for subdir in ["models", "huggingface", "hub", ".cache/huggingface/hub"]: + subpath = os.path.join(location.path, subdir) + if os.path.exists(subpath): + search_paths.append(subpath) + + return scan_local_models(search_paths, max_depth=max_depth) + + +@dataclass +class CPUBuildFeatures: + """CPU features for build configuration.""" + + has_amx: bool + has_avx512: 
def detect_cpu_build_features() -> CPUBuildFeatures:
    """
    Detect CPU features for build configuration.

    This is used to auto-configure kt-kernel source builds.
    Reads /proc/cpuinfo on Linux to detect instruction set support;
    on macOS it queries sysctl (AVX2 only).

    Returns:
        CPUBuildFeatures with detection results
    """
    has_amx = False
    has_avx512 = False
    has_avx512_vnni = False
    has_avx512_bf16 = False
    has_avx2 = False

    if platform.system() == "Linux":
        try:
            with open("/proc/cpuinfo", "r") as f:
                content = f.read()

            # Get flags from the first processor only; flags are assumed
            # uniform across cores (hence the break below).
            for line in content.split("\n"):
                if line.startswith("flags"):
                    flags = line.split(":")[1].strip().split()
                    flags_lower = {f.lower() for f in flags}

                    # Check for AMX support (requires all three flags)
                    if {"amx_tile", "amx_int8", "amx_bf16"} <= flags_lower:
                        has_amx = True

                    # Check for AVX512 foundation support
                    if "avx512f" in flags_lower:
                        has_avx512 = True

                    # Check for AVX512 VNNI (both spellings are checked;
                    # kernel versions differ in how they report the flag)
                    if "avx512_vnni" in flags_lower or "avx512vnni" in flags_lower:
                        has_avx512_vnni = True

                    # Check for AVX512 BF16 (same spelling caveat)
                    if "avx512_bf16" in flags_lower or "avx512bf16" in flags_lower:
                        has_avx512_bf16 = True

                    # Check for AVX2
                    if "avx2" in flags_lower:
                        has_avx2 = True

                    break
        except (OSError, IOError):
            pass

    elif platform.system() == "Darwin":
        # macOS - use sysctl. NOTE(review): run_command is a module-level
        # helper defined elsewhere in this file; presumably returns stdout
        # or a falsy value on failure -- confirm against its definition.
        features_output = run_command(["sysctl", "-n", "machdep.cpu.features"])
        if features_output:
            flags_lower = {f.lower() for f in features_output.split()}
            has_avx2 = "avx2" in flags_lower
            # macOS doesn't have AMX or AVX512 typically

    # Determine recommended configuration: NATIVE whenever any SIMD tier was
    # detected, AMX enabled only when the full AMX flag set is present.
    if has_amx:
        recommended_instruct = "NATIVE"
        recommended_amx = True
    elif has_avx512:
        recommended_instruct = "NATIVE"
        recommended_amx = False
    elif has_avx2:
        recommended_instruct = "NATIVE"
        recommended_amx = False
    else:
        recommended_instruct = "AVX2"
        recommended_amx = False

    return CPUBuildFeatures(
        has_amx=has_amx,
        has_avx512=has_avx512,
        has_avx512_vnni=has_avx512_vnni,
        has_avx512_bf16=has_avx512_bf16,
        has_avx2=has_avx2,
        recommended_instruct=recommended_instruct,
        recommended_amx=recommended_amx,
    )
default_params={ + "kt-method": "FP8", + "kt-gpu-prefill-token-threshold": 4096, + "attention-backend": "flashinfer", + "fp8-gemm-backend": "triton", + "max-total-tokens": 100000, + "max-running-requests": 16, + "chunked-prefill-size": 32768, + "mem-fraction-static": 0.80, + "watchdog-timeout": 3000, + "served-model-name": "DeepSeek-V3.2", + "disable-shared-experts-fusion": True, + }, + description="DeepSeek V3.2 671B MoE model (latest)", + description_zh="DeepSeek V3.2 671B MoE 模型(最新)", + ), + ModelInfo( + name="DeepSeek-R1-0528", + hf_repo="deepseek-ai/DeepSeek-R1-0528", + aliases=["deepseek-r1-0528", "deepseek-r1", "dsr1", "r1", "r1-0528"], + type="moe", + default_params={ + "kt-num-gpu-experts": 1, + "attention-backend": "triton", + "disable-shared-experts-fusion": True, + "kt-method": "AMXINT4", + }, + description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)", + description_zh="DeepSeek R1-0528 推理模型(2025年5月,改进的推理深度)", + ), + ModelInfo( + name="Kimi-K2-Thinking", + hf_repo="moonshotai/Kimi-K2-Thinking", + aliases=["kimi-k2-thinking", "kimi-thinking", "k2-thinking", "kimi", "k2"], + type="moe", + default_params={ + "kt-method": "RAWINT4", + "kt-gpu-prefill-token-threshold": 400, + "attention-backend": "flashinfer", + "max-total-tokens": 100000, + "max-running-requests": 16, + "chunked-prefill-size": 32768, + "mem-fraction-static": 0.80, + "watchdog-timeout": 3000, + "served-model-name": "Kimi-K2-Thinking", + "disable-shared-experts-fusion": True, + }, + description="Moonshot Kimi K2 Thinking MoE model", + description_zh="月之暗面 Kimi K2 Thinking MoE 模型", + ), + ModelInfo( + name="MiniMax-M2", + hf_repo="MiniMaxAI/MiniMax-M2", + aliases=["minimax-m2", "m2"], + type="moe", + default_params={ + "kt-method": "FP8", + "kt-gpu-prefill-token-threshold": 4096, + "attention-backend": "flashinfer", + "fp8-gemm-backend": "triton", + "max-total-tokens": 100000, + "max-running-requests": 16, + "chunked-prefill-size": 32768, + "mem-fraction-static": 
class ModelRegistry:
    """Registry of supported models with fuzzy matching.

    Holds the built-in model list plus user-defined entries loaded from
    ``<config_dir>/registry.yaml``. All lookups are case-insensitive and
    work on both canonical names and aliases.
    """

    def __init__(self):
        """Initialize the registry from built-in and user-defined models."""
        # lowercased canonical name -> ModelInfo
        self._models: dict[str, ModelInfo] = {}
        # lowercased alias -> lowercased canonical name
        self._aliases: dict[str, str] = {}
        self._load_builtin_models()
        self._load_user_models()

    def _load_builtin_models(self) -> None:
        """Register the models shipped with kt-cli."""
        for model in BUILTIN_MODELS:
            self._register(model)

    def _load_user_models(self) -> None:
        """Load user-defined models from registry.yaml, if present.

        Malformed YAML and I/O errors are swallowed deliberately: user
        entries are optional and must never break CLI startup.
        """
        settings = get_settings()
        registry_file = settings.config_dir / "registry.yaml"

        if registry_file.exists():
            try:
                with open(registry_file, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f) or {}

                for name, info in data.get("models", {}).items():
                    model = ModelInfo(
                        name=name,
                        hf_repo=info.get("hf_repo", ""),
                        aliases=info.get("aliases", []),
                        type=info.get("type", "moe"),
                        gpu_vram_gb=info.get("gpu_vram_gb", 0),
                        cpu_ram_gb=info.get("cpu_ram_gb", 0),
                        default_params=info.get("default_params", {}),
                        description=info.get("description", ""),
                        description_zh=info.get("description_zh", ""),
                        max_tensor_parallel_size=info.get("max_tensor_parallel_size"),
                    )
                    self._register(model)
            except (yaml.YAMLError, OSError):
                pass

    def _register(self, model: ModelInfo) -> None:
        """Register a model under its lowercased name and all its aliases."""
        self._models[model.name.lower()] = model

        # Register aliases
        for alias in model.aliases:
            self._aliases[alias.lower()] = model.name.lower()

    def get(self, name: str) -> Optional[ModelInfo]:
        """Get a model by exact name or alias (case-insensitive)."""
        name_lower = name.lower()

        # Check direct match
        if name_lower in self._models:
            return self._models[name_lower]

        # Check aliases
        if name_lower in self._aliases:
            return self._models[self._aliases[name_lower]]

        return None

    def search(self, query: str, limit: int = 10) -> list[ModelInfo]:
        """Search for models using fuzzy matching.

        Args:
            query: Search query
            limit: Maximum number of results

        Returns:
            List of matching models, sorted by relevance (best first)
        """
        query_lower = query.lower()
        results: list[tuple[float, ModelInfo]] = []

        for model in self._models.values():
            score = self._match_score(query_lower, model)
            if score > 0:
                results.append((score, model))

        # Sort by score descending
        results.sort(key=lambda x: x[0], reverse=True)

        return [model for _, model in results[:limit]]

    def _match_score(self, query: str, model: ModelInfo) -> float:
        """Calculate match score for a model.

        Returns a score between 0 and 1, where 1 is an exact match.
        Tiers: exact name (1.0) > exact alias (0.95) > name substring (0.8)
        > alias substring (0.7) > hf_repo substring (0.6) > token overlap
        (at most 0.5). *query* is expected to be lowercased by the caller.
        """
        # Check exact match
        if query == model.name.lower():
            return 1.0

        # Check alias exact match
        for alias in model.aliases:
            if query == alias.lower():
                return 0.95

        # Check if query is contained in name
        if query in model.name.lower():
            return 0.8

        # Check if query is contained in aliases
        for alias in model.aliases:
            if query in alias.lower():
                return 0.7

        # Check if query is contained in hf_repo
        if query in model.hf_repo.lower():
            return 0.6

        # Fuzzy fallback: fraction of query tokens present in the name.
        query_parts = re.split(r"[-_.\s]", query)
        name_lower = model.name.lower()

        matches = sum(1 for part in query_parts if part and part in name_lower)
        if matches > 0:
            return 0.5 * (matches / len(query_parts))

        return 0.0

    def list_all(self) -> list[ModelInfo]:
        """List all registered models."""
        return list(self._models.values())

    def find_local_models(self) -> list[tuple[ModelInfo, Path]]:
        """Find models that are downloaded locally in any configured model path.

        Returns:
            List of (ModelInfo, path) tuples for local models
        """
        settings = get_settings()
        model_paths = settings.get_model_paths()
        results = []

        for model in self._models.values():
            found = False
            # Search in all configured model directories
            for models_dir in model_paths:
                if not models_dir.exists():
                    continue

                # Check common on-disk naming conventions for downloads.
                possible_paths = [
                    models_dir / model.name,
                    models_dir / model.name.lower(),
                    models_dir / model.hf_repo.split("/")[-1],
                    models_dir / model.hf_repo.replace("/", "--"),
                ]

                for path in possible_paths:
                    # A directory only counts as a model if config.json exists.
                    if path.exists() and (path / "config.json").exists():
                        results.append((model, path))
                        found = True
                        break

                if found:
                    break

        return results


# Global registry instance (lazily created by get_registry()).
_registry: Optional[ModelRegistry] = None


def get_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry


# ============================================================================
# Model-specific parameter computation functions
# ============================================================================


def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for the DeepSeek V3/R1 family.

    Reserves 16 GB per GPU (presumably for non-expert weights and KV
    cache -- TODO confirm against deployment docs), then budgets one
    expert per 3 GB of the remaining total VRAM.
    """
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram // 3


def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for Kimi K2 Thinking.

    Same 16 GB per-GPU reservation as DeepSeek; budgets two experts per
    3 GB of the remaining total VRAM.
    """
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram * 2 // 3


def compute_minimax_m2_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for MiniMax M2/M2.1.

    Same 16 GB per-GPU reservation; budgets one expert per GB of the
    remaining total VRAM.
    """
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram


# Model name to computation function mapping
MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
    "DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
    "DeepSeek-V3.2": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
    "MiniMax-M2": compute_minimax_m2_gpu_experts,
    "MiniMax-M2.1": compute_minimax_m2_gpu_experts,  # Same as M2
}
def check_sglang_installation() -> dict:
    """Check if SGLang is installed and get its metadata.

    Returns:
        dict with keys:
        - installed: bool
        - version: str or None
        - location: str or None (installation path)
        - editable: bool (whether installed in editable mode)
        - git_info: dict or None (git remote and branch if available)
        - from_source: bool (whether installed from source repository)
    """
    try:
        # Try to import sglang; ImportError here means "not installed".
        import sglang

        version = getattr(sglang, "__version__", None)

        # Use pip show to get detailed package information
        location = None
        editable = False
        git_info = None
        from_source = False

        try:
            # Get pip show output
            result = subprocess.run(
                [sys.executable, "-m", "pip", "show", "sglang"],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode == 0:
                # Parse the "Key: value" lines emitted by pip show.
                pip_info = {}
                for line in result.stdout.split("\n"):
                    if ":" in line:
                        key, value = line.split(":", 1)
                        pip_info[key.strip()] = value.strip()

                location = pip_info.get("Location")
                editable_location = pip_info.get("Editable project location")

                # An editable install's real location is the project dir.
                if editable_location:
                    editable = True
                    location = editable_location
        except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
            # Fallback to module location when pip itself is unavailable.
            if hasattr(sglang, "__file__") and sglang.__file__:
                location = str(Path(sglang.__file__).parent.parent)

        # Check if it's installed from source (has .git directory)
        if location:
            git_root = None
            check_path = Path(location)

            # Check current directory and up to 2 parent directories
            for _ in range(3):
                git_dir = check_path / ".git"
                if git_dir.exists():
                    git_root = check_path
                    from_source = True
                    break
                if check_path.parent == check_path:  # Reached filesystem root
                    break
                check_path = check_path.parent

            if from_source and git_root:
                # Try to get git remote and branch info (best effort)
                try:
                    # Get remote URL
                    result = subprocess.run(
                        ["git", "remote", "get-url", "origin"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    remote_url = result.stdout.strip() if result.returncode == 0 else None

                    # Extract org/repo from URL
                    remote_short = None
                    if remote_url:
                        # Handle both https and git@ URLs
                        if "github.com" in remote_url:
                            parts = remote_url.rstrip("/").replace(".git", "").split("github.com")[-1]
                            remote_short = parts.lstrip("/").lstrip(":")

                    # Get current branch
                    result = subprocess.run(
                        ["git", "branch", "--show-current"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    branch = result.stdout.strip() if result.returncode == 0 else None

                    if remote_url or branch:
                        git_info = {
                            "remote": remote_short or remote_url,
                            "branch": branch,
                        }
                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                    pass

        return {
            "installed": True,
            "version": version,
            "location": location,
            "editable": editable,
            "git_info": git_info,
            "from_source": from_source,
        }
    except ImportError:
        return {
            "installed": False,
            "version": None,
            "location": None,
            "editable": False,
            "git_info": None,
            "from_source": False,
        }
def print_sglang_install_instructions() -> None:
    """Print SGLang installation instructions to console."""
    instructions = get_sglang_install_instructions()
    console.print(instructions)


def check_sglang_and_warn() -> bool:
    """Check if SGLang is installed, print warning if not.

    Prints full install instructions when SGLang is missing, and a
    softer warning when it is installed from PyPI rather than from a
    source checkout.

    Returns:
        True if SGLang is installed, False otherwise.
    """
    info = check_sglang_installation()

    if not info["installed"]:
        print_sglang_install_instructions()
        return False

    # Check if installed from PyPI (not recommended)
    # NOTE(review): info["installed"] is always True past the early return
    # above, so this condition reduces to `not info["from_source"]`.
    if info["installed"] and not info["from_source"]:
        from kt_kernel.cli.utils.console import print_warning

        print_warning(t("sglang_pypi_warning"))
        console.print()
        console.print("[dim]" + t("sglang_recommend_source") + "[/dim]")
        console.print()

    return True


def _get_sglang_kt_kernel_cache_path() -> Path:
    """Get the path to the sglang kt-kernel support cache file.

    Side effect: creates ~/.ktransformers/cache if it does not exist yet.
    """
    cache_dir = Path.home() / ".ktransformers" / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir / "sglang_kt_kernel_supported"
+ """ + cache_path = _get_sglang_kt_kernel_cache_path() + if cache_path.exists(): + try: + content = cache_path.read_text().strip().lower() + return content == "true" + except (OSError, IOError): + pass + return False + + +def _save_sglang_kt_kernel_cache(supported: bool) -> None: + """Save the sglang kt-kernel support check result to cache.""" + cache_path = _get_sglang_kt_kernel_cache_path() + try: + cache_path.write_text("true" if supported else "false") + except (OSError, IOError): + pass # Ignore cache write errors + + +def clear_sglang_kt_kernel_cache() -> None: + """Clear the sglang kt-kernel support cache, forcing a re-check on next run.""" + cache_path = _get_sglang_kt_kernel_cache_path() + try: + if cache_path.exists(): + cache_path.unlink() + except (OSError, IOError): + pass + + +def check_sglang_kt_kernel_support(use_cache: bool = True, silent: bool = False) -> dict: + """Check if SGLang supports kt-kernel parameters (--kt-gpu-prefill-token-threshold). + + This function runs `python -m sglang.launch_server --help` and checks if the + output contains the `--kt-gpu-prefill-token-threshold` parameter. This parameter + is only available in the kvcache-ai/sglang fork, not in the official sglang. + + The result is cached after the first successful check to avoid repeated checks. + + Args: + use_cache: If True, use cached result if available. Default is True. + silent: If True, don't print checking message. Default is False. 
+ + Returns: + dict with keys: + - supported: bool - True if kt-kernel parameters are supported + - help_output: str or None - The help output from sglang.launch_server + - error: str or None - Error message if check failed + - from_cache: bool - True if result was from cache + """ + from kt_kernel.cli.utils.console import print_step + + # Check cache first + if use_cache and _is_sglang_kt_kernel_cache_valid(): + return { + "supported": True, + "help_output": None, + "error": None, + "from_cache": True, + } + + # Print checking message + if not silent: + print_step(t("sglang_checking_kt_kernel_support")) + + try: + result = subprocess.run( + [sys.executable, "-m", "sglang.launch_server", "--help"], + capture_output=True, + text=True, + timeout=30, + ) + + help_output = result.stdout + result.stderr + + # Check if --kt-gpu-prefill-token-threshold is in the help output + supported = "--kt-gpu-prefill-token-threshold" in help_output + + # Save to cache if supported + if supported: + _save_sglang_kt_kernel_cache(True) + + return { + "supported": supported, + "help_output": help_output, + "error": None, + "from_cache": False, + } + + except subprocess.TimeoutExpired: + return { + "supported": False, + "help_output": None, + "error": "Timeout while checking sglang.launch_server --help", + "from_cache": False, + } + except FileNotFoundError: + return { + "supported": False, + "help_output": None, + "error": "Python interpreter not found", + "from_cache": False, + } + except Exception as e: + return { + "supported": False, + "help_output": None, + "error": str(e), + "from_cache": False, + } + + +def print_sglang_kt_kernel_instructions() -> None: + """Print instructions for installing the kvcache-ai fork of SGLang with kt-kernel support.""" + from kt_kernel.cli.i18n import get_lang + + lang = get_lang() + + if lang == "zh": + instructions = """ +[bold red]SGLang 不支持 kt-kernel[/bold red] + +您当前安装的 SGLang 不包含 kt-kernel 支持。 +kt-kernel 需要使用 kvcache-ai 维护的 SGLang 分支。 + 
+[bold]请按以下步骤重新安装 SGLang:[/bold] + +[cyan]1. 卸载当前的 SGLang:[/cyan] + pip uninstall sglang -y + +[cyan]2. 克隆 kvcache-ai 的 SGLang 仓库:[/cyan] + git clone https://github.com/kvcache-ai/sglang.git + cd sglang + +[cyan]3. 安装 SGLang:[/cyan] + pip install -e "python[all]" + +[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim] +""" + else: + instructions = """ +[bold red]SGLang does not support kt-kernel[/bold red] + +Your current SGLang installation does not include kt-kernel support. +kt-kernel requires the kvcache-ai maintained fork of SGLang. + +[bold]Please reinstall SGLang with the following steps:[/bold] + +[cyan]1. Uninstall current SGLang:[/cyan] + pip uninstall sglang -y + +[cyan]2. Clone the kvcache-ai SGLang repository:[/cyan] + git clone https://github.com/kvcache-ai/sglang.git + cd sglang + +[cyan]3. Install SGLang:[/cyan] + pip install -e "python[all]" + +[dim]Note: Make sure to run these commands in the correct Python environment[/dim] +""" + console.print(instructions) diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 0f89a75..1753d29 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -17,7 +17,7 @@ from typing import List, Optional from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer # Import backend implementations -from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper +from .utils.amx import AMXMoEWrapper, NativeMoEWrapper from .utils.llamafile import LlamafileMoEWrapper from .utils.moe_kernel import GeneralMoEWrapper @@ -77,7 +77,7 @@ class KTMoEWrapper: chunked_prefill_size: Maximum prefill chunk size cpu_save: Whether to save weights to CPU memory max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0. 
- method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8") + method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "LLAMAFILE", "MOE_INT4", "MOE_INT8") Returns: An instance of the appropriate backend implementation (e.g., AMXMoEWrapper) @@ -85,8 +85,8 @@ class KTMoEWrapper: # Select backend based on method if method in ["AMXINT4", "AMXINT8"]: backend_cls = AMXMoEWrapper - elif method == "RAWINT4": - backend_cls = RAWAMXMoEWrapper + elif method in ["RAWINT4", "FP8"]: + backend_cls = NativeMoEWrapper elif method == "LLAMAFILE": backend_cls = LlamafileMoEWrapper elif method in ["MOE_INT4", "MOE_INT8"]: diff --git a/kt-kernel/python/utils/__init__.py b/kt-kernel/python/utils/__init__.py index 729699f..c5715fa 100644 --- a/kt-kernel/python/utils/__init__.py +++ b/kt-kernel/python/utils/__init__.py @@ -4,13 +4,13 @@ Utilities for kt_kernel package. """ -from .amx import AMXMoEWrapper, RAWAMXMoEWrapper +from .amx import AMXMoEWrapper, NativeMoEWrapper from .llamafile import LlamafileMoEWrapper from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader __all__ = [ "AMXMoEWrapper", - "RAWAMXMoEWrapper", + "NativeMoEWrapper", "LlamafileMoEWrapper", "SafeTensorLoader", "CompressedSafeTensorLoader", diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index 273c8a5..c18800b 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -4,16 +4,16 @@ import ctypes # Use relative imports for package structure from ..experts_base import BaseMoEWrapper -from .loader import SafeTensorLoader, CompressedSafeTensorLoader +from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader from kt_kernel_ext.moe import MOEConfig try: - from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE + from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE _HAS_AMX_SUPPORT = True except (ImportError, AttributeError): 
_HAS_AMX_SUPPORT = False - AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None + AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None from typing import Optional @@ -303,10 +303,10 @@ class AMXMoEWrapper(BaseMoEWrapper): del self.down_scales -class RAWAMXMoEWrapper(BaseMoEWrapper): - """Wrapper for RAWINT4 experts stored in compressed SafeTensor format.""" +class NativeMoEWrapper(BaseMoEWrapper): + """Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format.""" - _compressed_loader_instance = None + _native_loader_instance = None def __init__( self, @@ -324,8 +324,12 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): max_deferred_experts_per_token: Optional[int] = None, method: str = "RAWINT4", ): - if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None: + if not _HAS_AMX_SUPPORT: + raise RuntimeError("AMX backend is not available.") + if method == "RAWINT4" and AMXInt4_KGroup_MOE is None: raise RuntimeError("AMX backend with RAWINT4 support is not available.") + if method == "FP8" and AMXFP8_MOE is None: + raise RuntimeError("AMX backend with FP8 support is not available.") super().__init__( layer_idx=layer_idx, @@ -343,9 +347,14 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): method=method, ) - if RAWAMXMoEWrapper._compressed_loader_instance is None: - RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path) - self.loader = RAWAMXMoEWrapper._compressed_loader_instance + if NativeMoEWrapper._native_loader_instance is None: + if method == "RAWINT4": + NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path) + elif method == "FP8": + NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path) + else: + raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}") + self.loader = NativeMoEWrapper._native_loader_instance self.gate_weights = None self.up_weights = None @@ -378,9 +387,17 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): 
self.down_weights = weights["down"] # Convert scales to bf16 individually - self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]] - self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]] - self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]] + # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]] + # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]] + # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]] + self.gate_scales = weights["gate_scale"] + self.up_scales = weights["up_scale"] + self.down_scales = weights["down_scale"] + if self.method == "RAWINT4": + assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4" + elif self.method == "FP8": + assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8" + t2 = time.time() # Build pointer lists: [numa_id][expert_id] -> pointer @@ -404,18 +421,6 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): moe_config.pool = self.cpu_infer.backend_ moe_config.max_len = self.chunked_prefill_size - # Infer group_size from scale shape (column-major layout) - # For gate/up projection: in_features = hidden_size - # So: group_size = hidden_size / scale.shape[1] - scale_shape = self.gate_scales[0].shape - group_size = self.hidden_size // scale_shape[1] - print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}") - - moe_config.quant_config.bits = 4 - moe_config.quant_config.group_size = group_size - - moe_config.quant_config.zero_point = False - # Use gate_projs instead of gate_proj for per-expert pointers moe_config.gate_projs = gate_ptrs moe_config.up_projs = up_ptrs @@ -424,7 +429,21 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): moe_config.up_scales = up_scale_ptrs moe_config.down_scales = down_scale_ptrs - self.moe = AMXInt4_KGroup_MOE(moe_config) + # Infer group_size from 
scale shape (column-major layout) + # For gate/up projection: in_features = hidden_size + # So: group_size = hidden_size / scale.shape[1] + + if self.method == "RAWINT4": + group_size = self.hidden_size // self.gate_scales[0].shape[1] + moe_config.quant_config.bits = 4 + moe_config.quant_config.group_size = group_size + moe_config.quant_config.zero_point = False + self.moe = AMXInt4_KGroup_MOE(moe_config) + elif self.method == "FP8": + moe_config.quant_config.bits = 8 + moe_config.quant_config.group_size = 128 + moe_config.quant_config.zero_point = False + self.moe = AMXFP8_MOE(moe_config) t4 = time.time() self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr())) @@ -440,7 +459,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): t6 = time.time() print( - f"[RAWAMXMoEWrapper Layer {self.layer_idx}] " + f"[NativeMoEWrapper Layer {self.layer_idx}] " f"load_experts: {(t1-t0)*1000:.1f}ms, " f"prepare_tensors: {(t2-t1)*1000:.1f}ms, " f"build_ptrs: {(t3-t2)*1000:.1f}ms, " @@ -453,7 +472,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): def submit_write_weight_scale_to_buffer( self, gpu_tp_count: int, - gpu_experts_num: int, + expert_id: int, w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, @@ -477,7 +496,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper): self.cpu_infer.submit( self.moe.write_weight_scale_to_buffer_task( gpu_tp_count, - gpu_experts_num, + expert_id, w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, diff --git a/kt-kernel/python/utils/llamafile.py b/kt-kernel/python/utils/llamafile.py index 68dce64..d6086a9 100644 --- a/kt-kernel/python/utils/llamafile.py +++ b/kt-kernel/python/utils/llamafile.py @@ -219,4 +219,4 @@ class LlamafileMoEWrapper(BaseMoEWrapper): self.cpu_infer.sync() # Drop original weights after loading - self.weights_to_keep = None \ No newline at end of file + self.weights_to_keep = None diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py index db689f4..852ceab 100644 --- 
a/kt-kernel/python/utils/loader.py +++ b/kt-kernel/python/utils/loader.py @@ -237,6 +237,117 @@ class SafeTensorLoader: return name in self.tensor_file_map +class FP8SafeTensorLoader(SafeTensorLoader): + """Loader for FP8 expert weights with auto-detection of naming formats. + + Supported formats: + - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight + - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight + + The format is auto-detected during initialization. + """ + + # Known MoE naming formats: (experts_path_template, gate_name, up_name, down_name) + MOE_FORMATS = { + "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"), + "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), + } + + def __init__(self, file_path: str): + super().__init__(file_path) + self._detected_format = None + self._detect_format() + + def _detect_format(self): + """Auto-detect the MoE naming format by checking tensor keys.""" + # Sample some tensor names to detect format + sample_keys = list(self.tensor_file_map.keys())[:1000] + + for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items(): + # Check if any key matches this format pattern + # Look for pattern like: model.layers.0.{experts_path}.0.{gate_name}.weight + for key in sample_keys: + if ".experts." 
in key and f".{gate}.weight" in key: + # Verify the path template matches + if "block_sparse_moe.experts" in key and fmt_name == "mixtral": + self._detected_format = fmt_name + print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") + return + elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek": + self._detected_format = fmt_name + print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") + return + + # Default to deepseek if no format detected + self._detected_format = "deepseek" + print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek") + + def _get_experts_prefix(self, base_key: str) -> str: + """Get the experts prefix based on detected format.""" + path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format] + return path_tpl.format(base=base_key) + + def _get_proj_names(self): + """Get projection names (gate, up, down) based on detected format.""" + _, gate, up, down = self.MOE_FORMATS[self._detected_format] + return gate, up, down + + def load_tensor(self, key: str, device: str = "cpu"): + if key not in self.tensor_file_map: + raise KeyError(f"Key {key} not found in Safetensor files") + file = self.tensor_file_map[key] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(key) + if device == "cpu": + return tensor + return tensor.to(device) + + def load_experts(self, base_key: str, device: str = "cpu"): + """Load FP8 expert weights and their block-wise scale_inv tensors.""" + experts_prefix = self._get_experts_prefix(base_key) + gate_name, up_name, down_name = self._get_proj_names() + + expert_count = 0 + while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"): + expert_count += 1 + + if expert_count == 0: + raise ValueError(f"No experts found for key {experts_prefix}") + + gate_weights = [None] * expert_count + up_weights = [None] * expert_count + down_weights = [None] * expert_count 
+ gate_scales = [None] * expert_count + up_scales = [None] * expert_count + down_scales = [None] * expert_count + + for exp_id in range(expert_count): + gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight" + up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight" + down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight" + gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv" + up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv" + down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv" + + gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous() + up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous() + down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous() + gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous() + up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous() + down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous() + + return { + "gate": gate_weights, + "up": up_weights, + "down": down_weights, + "gate_scale": gate_scales, + "up_scale": up_scales, + "down_scale": down_scales, + } + + class CompressedSafeTensorLoader(SafeTensorLoader): """Loader for compressed SafeTensor layouts (RAWINT4 weights).""" diff --git a/kt-kernel/setup.py b/kt-kernel/setup.py index f63da9e..73c47a6 100644 --- a/kt-kernel/setup.py +++ b/kt-kernel/setup.py @@ -285,9 +285,9 @@ class CMakeBuild(build_ext): # Variant configurations: (name, CPUINFER_CPU_INSTRUCT, CPUINFER_ENABLE_AMX) variants = [ - ("amx", "AVX512", "ON"), # AVX512 + AMX + ("amx", "AVX512", "ON"), # AVX512 + AMX ("avx512", "AVX512", "OFF"), # AVX512 only - ("avx2", "AVX2", "OFF"), # AVX2 only + ("avx2", "AVX2", "OFF"), # AVX2 only ] for variant_name, cpu_instruct, enable_amx in variants: @@ -384,6 +384,7 @@ class CMakeBuild(build_ext): build_temp: Temporary build directory for CMake cfg: Build type (Release/Debug/etc.) 
""" + # Auto-detect CUDA toolkit if user did not explicitly set CPUINFER_USE_CUDA def detect_cuda_toolkit() -> bool: # Respect CUDA_HOME @@ -614,10 +615,26 @@ setup( author="kvcache-ai", license="Apache-2.0", python_requires=">=3.8", - packages=["kt_kernel", "kt_kernel.utils"], + packages=[ + "kt_kernel", + "kt_kernel.utils", + "kt_kernel.cli", + "kt_kernel.cli.commands", + "kt_kernel.cli.config", + "kt_kernel.cli.utils", + ], package_dir={ "kt_kernel": "python", "kt_kernel.utils": "python/utils", + "kt_kernel.cli": "python/cli", + "kt_kernel.cli.commands": "python/cli/commands", + "kt_kernel.cli.config": "python/cli/config", + "kt_kernel.cli.utils": "python/cli/utils", + }, + entry_points={ + "console_scripts": [ + "kt=kt_kernel.cli.main:main", + ], }, ext_modules=[CMakeExtension("kt_kernel.kt_kernel_ext", str(REPO_ROOT))], cmdclass={"build_ext": CMakeBuild}, diff --git a/kt-kernel/test/per_commit/test_basic_cpu.py b/kt-kernel/test/per_commit/test_basic_cpu.py index 46c3c0a..92aad91 100644 --- a/kt-kernel/test/per_commit/test_basic_cpu.py +++ b/kt-kernel/test/per_commit/test_basic_cpu.py @@ -17,6 +17,7 @@ register_cpu_ci(est_time=30, suite="default") # Check if kt_kernel_ext is available try: import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module HAS_KT_KERNEL = True except ImportError: @@ -51,7 +52,7 @@ def test_basic_module_attributes(): pytest.skip("kt_kernel_ext not built or available") # Check for key attributes/functions - assert hasattr(kt_kernel_ext, 'CPUInfer'), "kt_kernel_ext should have CPUInfer class" + assert hasattr(kt_kernel_ext, "CPUInfer"), "kt_kernel_ext should have CPUInfer class" def run_all_tests(): diff --git a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4.py b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4.py index 9ded113..c59818c 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4.py +++ 
b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4.py @@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module HAS_DEPS = True except ImportError as e: @@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = mlp_torch( - tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i] - ) + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx @@ -96,9 +95,7 @@ def test_moe_amx_int4_accuracy(): pytest.skip(f"Dependencies not available: {import_error}") global physical_to_logical_map - physical_to_logical_map = torch.tensor( - data=range(expert_num), device="cpu", dtype=torch.int64 - ).contiguous() + physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() CPUInfer = kt_kernel_ext.CPUInfer(60) @@ -133,9 +130,7 @@ def test_moe_amx_int4_accuracy(): ) # Create MOE config - config = kt_kernel_ext.moe.MOEConfig( - expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0 - ) + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) config.max_len = max_len config.gate_proj = gate_proj.data_ptr() config.up_proj = up_proj.data_ptr() @@ -176,14 +171,10 @@ def test_moe_amx_int4_accuracy(): CPUInfer.sync() # Run torch reference - t_output = moe_torch( - input_data, expert_ids, weights, gate_proj, up_proj, down_proj - ) + t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj) # Calculate relative difference - diff = torch.mean(torch.abs(output - t_output)) / torch.mean( - torch.abs(t_output) - ) + diff = torch.mean(torch.abs(output - t_output)) / 
torch.mean(torch.abs(t_output)) print(f"Iteration {i}, diff = {diff:.6f}") # INT4 should have diff < 0.35 @@ -205,6 +196,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1.py b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1.py index 30f88aa..61f59c7 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1.py +++ b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1.py @@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module HAS_DEPS = True except ImportError as e: @@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = mlp_torch( - tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i] - ) + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx @@ -96,9 +95,7 @@ def test_moe_amx_int4_1_accuracy(): pytest.skip(f"Dependencies not available: {import_error}") global physical_to_logical_map - physical_to_logical_map = torch.tensor( - data=range(expert_num), device="cpu", dtype=torch.int64 - ).contiguous() + physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() CPUInfer = kt_kernel_ext.CPUInfer(60) @@ -133,9 +130,7 @@ def test_moe_amx_int4_1_accuracy(): ) # Create MOE config - config = kt_kernel_ext.moe.MOEConfig( - expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0 - ) + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) config.max_len = max_len config.gate_proj = gate_proj.data_ptr() 
config.up_proj = up_proj.data_ptr() @@ -176,14 +171,10 @@ def test_moe_amx_int4_1_accuracy(): CPUInfer.sync() # Run torch reference - t_output = moe_torch( - input_data, expert_ids, weights, gate_proj, up_proj, down_proj - ) + t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj) # Calculate relative difference - diff = torch.mean(torch.abs(output - t_output)) / torch.mean( - torch.abs(t_output) - ) + diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print(f"Iteration {i}, diff = {diff:.6f}") # INT4_1 should have diff < 0.35 @@ -205,6 +196,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1k.py b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1k.py index 90e7501..8f0b943 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1k.py +++ b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1k.py @@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module HAS_DEPS = True except ImportError as e: @@ -69,9 +70,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = mlp_torch( - tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i] - ) + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx @@ -97,9 +96,7 @@ def test_moe_amx_int4_1k_accuracy(): pytest.skip(f"Dependencies not available: {import_error}") global physical_to_logical_map - physical_to_logical_map = torch.tensor( - data=range(expert_num), device="cpu", dtype=torch.int64 - ).contiguous() + physical_to_logical_map = 
torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() CPUInfer = kt_kernel_ext.CPUInfer(60) @@ -134,9 +131,7 @@ def test_moe_amx_int4_1k_accuracy(): ) # Create MOE config - config = kt_kernel_ext.moe.MOEConfig( - expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0 - ) + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) config.max_len = max_len config.gate_proj = gate_proj.data_ptr() config.up_proj = up_proj.data_ptr() @@ -180,14 +175,10 @@ def test_moe_amx_int4_1k_accuracy(): CPUInfer.sync() # Run torch reference - t_output = moe_torch( - input_data, expert_ids, weights, gate_proj, up_proj, down_proj - ) + t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj) # Calculate relative difference - diff = torch.mean(torch.abs(output - t_output)) / torch.mean( - torch.abs(t_output) - ) + diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print(f"Iteration {i}, diff = {diff:.6f}") # INT4_1K should have diff < 0.35 @@ -209,6 +200,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int8.py b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int8.py index eb91535..f0b4590 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_accuracy_int8.py +++ b/kt-kernel/test/per_commit/test_moe_amx_accuracy_int8.py @@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module HAS_DEPS = True except ImportError as e: @@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): if num_tokens == 0: continue tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = mlp_torch( - 
tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i] - ) + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) outputs.append(expert_out) start_idx = end_idx @@ -96,9 +95,7 @@ def test_moe_amx_int8_accuracy(): pytest.skip(f"Dependencies not available: {import_error}") global physical_to_logical_map - physical_to_logical_map = torch.tensor( - data=range(expert_num), device="cpu", dtype=torch.int64 - ).contiguous() + physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() CPUInfer = kt_kernel_ext.CPUInfer(60) @@ -133,9 +130,7 @@ def test_moe_amx_int8_accuracy(): ) # Create MOE config - config = kt_kernel_ext.moe.MOEConfig( - expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0 - ) + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) config.max_len = max_len config.gate_proj = gate_proj.data_ptr() config.up_proj = up_proj.data_ptr() @@ -174,14 +169,10 @@ def test_moe_amx_int8_accuracy(): CPUInfer.sync() # Run torch reference - t_output = moe_torch( - input_data, expert_ids, weights, gate_proj, up_proj, down_proj - ) + t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj) # Calculate relative difference - diff = torch.mean(torch.abs(output - t_output)) / torch.mean( - torch.abs(t_output) - ) + diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output)) print(f"Iteration {i}, diff = {diff:.6f}") # INT8 should have diff < 0.05 @@ -203,6 +194,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_bench_int4.py b/kt-kernel/test/per_commit/test_moe_amx_bench_int4.py index d050ab8..e63c572 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_bench_int4.py +++ b/kt-kernel/test/per_commit/test_moe_amx_bench_int4.py @@ -24,8 +24,10 @@ 
register_cpu_ci(est_time=300, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module from tqdm import tqdm + HAS_DEPS = True except ImportError as e: HAS_DEPS = False @@ -306,6 +308,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1.py b/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1.py index 8c5a231..c03b86c 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1.py +++ b/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1.py @@ -24,6 +24,7 @@ register_cpu_ci(est_time=300, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module from tqdm import tqdm diff --git a/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1k.py b/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1k.py index 81a9d60..63ba3d7 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1k.py +++ b/kt-kernel/test/per_commit/test_moe_amx_bench_int4_1k.py @@ -25,8 +25,10 @@ register_cpu_ci(est_time=300, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module from tqdm import tqdm + HAS_DEPS = True except ImportError as e: HAS_DEPS = False @@ -156,11 +158,7 @@ def test_moe_amx_int4_1k_benchmark(): CPUInfer = kt_kernel_ext.CPUInfer(worker_config) # Physical to logical map for weight loading - physical_to_logical_map = torch.tensor( - data=range(expert_num), - device="cpu", - dtype=torch.int64 - ).contiguous() + physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() # Initialize MOE layers moes = [] @@ -322,6 +320,7 @@ def 
run_all_tests(): except Exception as e: print(f"\nTest failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/kt-kernel/test/per_commit/test_moe_amx_bench_int8.py b/kt-kernel/test/per_commit/test_moe_amx_bench_int8.py index 559c50e..f5d9f85 100644 --- a/kt-kernel/test/per_commit/test_moe_amx_bench_int8.py +++ b/kt-kernel/test/per_commit/test_moe_amx_bench_int8.py @@ -24,8 +24,10 @@ register_cpu_ci(est_time=300, suite="default") try: import torch import kt_kernel # Import kt_kernel first to register kt_kernel_ext + kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module from tqdm import tqdm + HAS_DEPS = True except ImportError as e: HAS_DEPS = False @@ -51,7 +53,6 @@ worker_config_dict = { CPUINFER_PARAM = 60 - def get_git_commit(): """Get current git commit information.""" result = {} @@ -307,6 +308,7 @@ def run_all_tests(): except Exception as e: print(f"\n✗ Test failed: {e}") import traceback + traceback.print_exc() sys.exit(1) diff --git a/version.py b/version.py index 758355b..6671b81 100644 --- a/version.py +++ b/version.py @@ -3,4 +3,4 @@ KTransformers version information. Shared across kt-kernel and kt-sft modules. """ -__version__ = "0.4.4" +__version__ = "0.4.5"