diff --git a/AUTHORS b/AUTHORS index 1bd36158..818feaf6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,782 +1,6 @@ -# date: Wed Jun 26 19:36:34 EEST 2024 -# this file is auto-generated by scripts/gen-authors.sh - -0cc4m -0xspringtime <110655352+0xspringtime@users.noreply.github.com> -20kdc -2f38b454 -3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> -44670 <44670@users.noreply.github.com> -AN Long -AT -Aarni Koskela -Aaron Miller -Aaryaman Vasishta -Abheek Gulati -Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> -Abhishek Gopinath K <31348521+overtunned@users.noreply.github.com> -Adithya Balaji -AdithyanI -Adrian -Adrian Hesketh -Ahmet Zeer -AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> -Aisuko -Akarshan Biswas -Albert Jin -Alberto <57916483+albbus-stack@users.noreply.github.com> -Alex -Alex Azarov -Alex Azarov -Alex Klinkhamer -Alex Klinkhamer -Alex Nguyen -Alex Petenchea -Alex Renda -Alex von Gluck IV -Alexey Parfenov -Ali Chraghi <63465728+alichraghi@users.noreply.github.com> -Ali Nehzat -Ali Tariq -Alon -AlpinDale <52078762+AlpinDale@users.noreply.github.com> -Amir -AmirAli Mirian <37371367+amiralimi@users.noreply.github.com> -Ananta Bastola -Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> -András Salamon -Andrei -Andrew Canis -Andrew Downing -Andrew Duffy -Andrew Godfrey -Andy Tai -Arik Poznanski -Artem -Artem Zinnatullin -Artyom Lebedev -Asbjørn Olling -Ásgeir Bjarni Ingvarsson -Ashish <1856117+ashishdatta@users.noreply.github.com> -Ashok Gelal <401055+ashokgelal@users.noreply.github.com> -Ashraful Islam -Atsushi Tatsuma -Austin <77757836+teleprint-me@users.noreply.github.com> -AustinMroz -BADR -Bach Le -Bailey Chittle <39804642+bachittle@users.noreply.github.com> -BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> -Bartowski -Behnam M <58621210+ibehnam@users.noreply.github.com> -Ben Ashbaugh -Ben Garney -Ben Siraphob -Ben Williams -Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> -Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> -Bernat Vadell -Bingan <70050083+binganao@users.noreply.github.com> -Bodo Graumann -Bono Lv -Borislav Stanimirov -Branden Butler -Brian -Bruce MacDonald -Bryan Honof -CJ Pais -CRD716 -Calvin Laurenson -Cameron -Cameron Kaiser -Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> -Casey Primozic -Casey Primozic -CausalLM <148736309+CausalLM@users.noreply.github.com> -Cebtenzzre -Chad Brewbaker -Chao Jiang -Cheng Shao -Chris Elrod -Chris Kuehl -Christian Demsar -Christian Demsar -Christian Falch <875252+chrfalch@users.noreply.github.com> -Christian Kögler -Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> -Clark Saben <76020733+csaben@users.noreply.github.com> -Clint Herron -CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> -Cuong Trinh Manh -DAN™ -Damian Stewart -Dane Madsen -DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> -Daniel Bevenius -Daniel Drake -Daniel Hiltgen -Daniel Illescas Romero -Daniele <57776841+daniandtheweb@users.noreply.github.com> -DannyDaemonic -Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> -Dave -Dave Airlie -Dave Airlie -Dave Della Costa -David Friehs -David Kennedy -David Pflug -David Renshaw -David Sommers <12738+databyte@users.noreply.github.com> -David Yang -Dawid Potocki -Dawid Wysocki <62249621+TortillaZHawaii@users.noreply.github.com> -Dean -Deins -Deven Mistry <31466137+deven367@users.noreply.github.com> -Didzis Gosko -Djip007 -Don Mahurin -DooWoong Lee (David) -Doomsdayrs <38189170+Doomsdayrs@users.noreply.github.com> -Douglas Hanley -Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> -Ebey Abraham -Ed Lee -Ed Lepedus -Eddie-Wang -Edward Taylor -Elaine -Elbios <141279586+Elbios@users.noreply.github.com> -Elton Kola -Engininja2 <139037756+Engininja2@users.noreply.github.com> -Equim -Eric Sommerlade -Eric Zhang <34133756+EZForever@users.noreply.github.com> -Erik Garrison -Erik Scholz -Ettore Di Giacinto -Evan Jones -Evan Miller -Eve <139727413+netrunnereve@users.noreply.github.com> -Evgeny Kurnevsky -Ewout ter Hoeven -ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> -FK -Fabian -Fabio R. Sluzala -Faez Shakil -FantasyGmm <16450052+FantasyGmm@users.noreply.github.com> -Fattire <528174+fat-tire@users.noreply.github.com> -Felix -Finn Voorhees -Firat -Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> -Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> -Francisco Melo <43780565+francis2tm@users.noreply.github.com> -Frank Mai -FrankHB -Fred Douglas <43351173+fredlas@users.noreply.github.com> -Frederik Vogel -Gabe Goodhart -GainLee -Galunid -Gary Linscott -Gary Mulder -Gavin Zhao -Genkagaku.GPT -Georgi Gerganov -Gilad S -Giuseppe Scrivano -GiviMAD -Govlzkoy -Guillaume "Vermeille" Sanchez -Guillaume Wenzek -Guoteng <32697156+SolenoidWGT@users.noreply.github.com> -Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> -Haggai Nuchi -Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> -Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com> -HanishKVC -Haohui Mai -Haoxiang Fei -Harald Fernengel -Hatsune Miku <129688334+at8u@users.noreply.github.com> -HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> -Henk Poley -Henri Vasserman -Henrik Forstén -Herman Semenov -Hesen Peng -Hoang Nguyen -Hong Bo PENG -Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> -Howard Su -Hua Jiang -Huawei Lin -Hugo Roussel -Ian Bull -Ian Bull -Ian Scrivener -Ido S -IgnacioFDM -Igor Okulist -Ikko Eltociear Ashimine -Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> -Ionoclast Laboratories -Isaac McFadyen -IsaacDynamo <61521674+IsaacDynamo@users.noreply.github.com> -Ivan Komarov -Ivan Stepanov -JH23X <165871467+JH23X@users.noreply.github.com> -Jack Mousseau -JackJollimore <130917767+JackJollimore@users.noreply.github.com> -Jaemin Son -Jag Chadha -Jakub N -James A Capozzoli <157492257+jac-jim@users.noreply.github.com> -James Reynolds -Jan Boon -Jan Boon -Jan Ploski -Jannis Schönleber -Jared Van Bortel -Jared Van Bortel -Jason McCartney -Jean-Christophe Hoelt -Jean-Michaël Celerier -Jed Fox -Jeffrey Quesnelle -Jesse Jojo Johnson -Jeximo -Jhen-Jie Hong -Jiahao Li -Jian Liao -JidongZhang-THU <1119708529@qq.com> -Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com> -Jiří Podivín <66251151+jpodivin@users.noreply.github.com> -Jiří Sejkora -Joan Fontanals -Joan Fontanals -Johan -Johannes Gäßler -Johannes Rudolph -John <78893154+cmp-nct@users.noreply.github.com> -John Balis -John Smith <67539080+kingsidelee@users.noreply.github.com> -JohnnyB -Jonas Wunderlich <32615971+jonas-w@users.noreply.github.com> -Jorge A <161275481+jorgealias@users.noreply.github.com> -Jose Maldonado <63384398+yukiteruamano@users.noreply.github.com> -Joseph Stahl <1269177+josephst@users.noreply.github.com> -Josh Ramer -Joyce -Juan Calderon-Perez <835733+gaby@users.noreply.github.com> -Judd -Julius Arkenberg -Jun Jie <71215065+junnjiee16@users.noreply.github.com> -Junyang Lin -Juraj Bednar -Justin Parker -Justin Suess -Justina Cho -Justine Tunney -Justine Tunney -Juuso Alasuutari -KASR -Kamil Tomšík -Karsten Weiss -Karthick -Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com> -Karthik Sethuraman -Kasumi <90275229+kasumi-1@users.noreply.github.com> -Kawrakow <48489457+ikawrakow@users.noreply.github.com> -Keiichi Tabata -Kenvix ⭐ -Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> -Kevin Gibbons -Kevin Ji <1146876+kevinji@users.noreply.github.com> -Kevin Kwok -Kevin Lo -Kolen Cheung -Konstantin Herud -Konstantin Zhuravlyov -Kunshang Ji -Kyle Liang -Kyle Mistele -Kylin <56434533+KyL0N@users.noreply.github.com> -Lars Grammel -Laura -Lee <44310445+lx200916@users.noreply.github.com> -Lee Drake -Leng Yue -Leon Knauer -LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> -Leonardo Neumann -Li Tan -Linwei Wang -LoganDark -LostRuins <39025047+LostRuins@users.noreply.github.com> -Luciano -Luo Tian -Lyle Dean -M. Yusuf Sarıgöz -Maarten ter Huurne -Mack Straight -Maël Kerbiriou -MaggotHATE -Manuel <44313466+makuche@users.noreply.github.com> -Marc Köhlbrugge -Marco Matthies <71844+marcom@users.noreply.github.com> -Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> -Marian Cepok -Mark Fairbairn -Marko Tasic -Markus Tavenrath -Martin Delille -Martin Krasser -Martin Schwaighofer -Marvin Gießing -Masaya, Kato <62578291+msy-kato@users.noreply.github.com> -MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> -Mateusz Charytoniuk -Matheus C. França -Matheus Gabriel Alves Silva -Mathieu Nayrolles -Mathijs de Bruin -Matt Clayton <156335168+mattjcly@users.noreply.github.com> -Matt Pulver -Matteo Boschini <12133566+mbosc@users.noreply.github.com> -Mattheus Chediak -Matthew Tejo -Matvey Soloviev -Max Krasnyansky -Max Krasnyansky -Maxime <672982+maximegmd@users.noreply.github.com> -Maximilian Winter -Meng Zhang -Meng, Hengyu -Merrick Christensen -Michael Coppola -Michael Hueschen -Michael Kesper -Michael Klimenko -Michael Podvitskiy -Michael Potter -Michael de Gans -Michaël de Vries -Mihai -Mike -Mikko Juola -Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> -Mirko185 -Mirror Azure <54669636+MirrorAzure@users.noreply.github.com> -Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> -Mohammadreza Hendiani -Mohammadreza Hendiani -Murilo Santana -Musab Gultekin -Nam D. Tran <42194884+namtranase@users.noreply.github.com> -Nathan Epstein -NawafAlansari <72708095+NawafAlansari@users.noreply.github.com> -Nebula -Neo Zhang <14088817+arthw@users.noreply.github.com> -Neo Zhang -Neo Zhang Jianyu -Neuman Vong -Nexesenex <124105151+Nexesenex@users.noreply.github.com> -Niall Coates <1349685+Niall-@users.noreply.github.com> -Nicolai Weitkemper -Nicolás Pérez -Nigel Bosch -Niklas Korz -Nikolas <127742645+nneubacher@users.noreply.github.com> -Nindaleth -Oleksandr Nikitin -Oleksii Maryshchenko -Olivier Chafik -Ondřej Čertík -Ouadie EL FAROUKI -Patrice Ferlet -Paul Tsochantaris -Pavol Rusnak -Pedro Cuenca -Peter Sugihara -Phil H <5756783+phiharri@users.noreply.github.com> -Philip Taron -Phillip Kravtsov -Pierre Alexandre SCHEMBRI -Pierrick Hymbert -Przemysław Pawełczyk -Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> -Qingyou Meng -Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> -RJ Adriaansen -Radoslav Gerganov -Radosław Gryta -Rahul Vivek Nair <68507071+RahulVivekNair@users.noreply.github.com> -Raj Hammeer Singh Hada -Ralph Soika -Rand Xie -Randall Fitzgerald -Reinforce-II -Ren Xuancheng -Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> -RhinoDevel -Riceball LEE -Richard Kiss -Richard Roberson -Rick G <26732651+TheFlipbook@users.noreply.github.com> -Rickard Edén -Rickard Hallerbäck -Rickey Bowers Jr -Riley Stewart -Rinne -Rinne -Robert Brisita <986796+rbrisita@users.noreply.github.com> -Robert Sung-wook Shin -Robey Holderith -Robyn -Roger Meier -Roland <14355895+rbur0425@users.noreply.github.com> -Romain D <90720+Artefact2@users.noreply.github.com> -Romain Neutron -Roman Parykin -Ron Evans -Ron Jailall -Ronny Brendel -Ronsor -Rowan Hart -Rune <43761327+Rune-AI@users.noreply.github.com> -Ryan Landay -Ryder Wishart -Ryuei -Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> -SakuraUmi -Salvador E. Tropea -Sam Spilsbury -Sami Farin <3876865+Safari77@users.noreply.github.com> -Samuel Maynard -Sang-Kil Park -Seb C <47074056+Sebby37@users.noreply.github.com> -Sebastián A -SebastianApel <13675545+SebastianApel@users.noreply.github.com> -Senemu <10880819+Senemu@users.noreply.github.com> -Sergey Alirzaev -Sergio López -Sertaç Özercan <852750+sozercan@users.noreply.github.com> -SeungWon Jeong <65549245+redlion0929@users.noreply.github.com> -ShadovvBeast -Shakhar Dasgupta -Shangning Xu <32517059+xushangning@users.noreply.github.com> -Shijie <821898965@qq.com> -Shintarou Okada -Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> -Shouzheng Liu -Shuichi Tsutsumi -Sigbjørn Skjæret -Simon Willison -Siwen Yu -Sky Yan -Slaren <2141330+slaren@users.noreply.github.com> -Slava Primenko -SoftwareRenderer <138734813+SoftwareRenderer@users.noreply.github.com> -Someone -Someone Serge -Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> -Spencer Sutton -Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> -Srinivas Billa -Stefan Sydow -Steffen Röcker -Stephan Walter -Stephen Nichols -Steve Grubb -Steven Prichard -Steven Roussey -Steward Garcia <57494570+FSSRepo@users.noreply.github.com> -Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> -SuperUserNameMan -Tai Duc Nguyen -Taikono-Himazin -Tameem <113388789+AhmadTameem@users.noreply.github.com> -Tamotsu Takahashi -Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com> -Thatcher Chamberlin -Theia Vogel -Thérence <13496987+Royalphax@users.noreply.github.com> -Thibault Terrasson -Thomas Klausner -Tim Miller -Timmy Knight -Timothy Cronin <40186632+4imothy@users.noreply.github.com> -Ting Lou -Ting Sun -Tobias Lütke -Tom C -Tom Jobbins <784313+TheBloke@users.noreply.github.com> -Tomas -Tomáš Pazdiora -Tristan Druyen -Tristan Ross -Tungsten842 <886724vf@anonaddy.me> -Tungsten842 -Tushar -UEXTM.com <84163508+uextm@users.noreply.github.com> -Ulrich Drepper -Uzo Nweke -Vaibhav Srivastav -Val Kharitonov -Valentin Konovalov -Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> -Victor Nogueira -Victor Z. Peng -Vlad -Vladimir -Vladimir Malyutin -Vladimir Zorin -Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> -WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> -Weird Constructor -Welby Seely -Wentai Zhang -WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> -William Tambellini -Willy Tarreau -Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com> -Wu Jian Ping -Wu Jian Ping -Xiake Sun -Xiang (Kevin) Li -Xiao-Yong Jin -XiaotaoChen -Xiaoyi Chen -Xingchen Song(宋星辰) -Xuan Son Nguyen -Yann Follet <131855179+YannFollet@users.noreply.github.com> -Yaroslav -Yazan Agha-Schrader -Yiming Cui -Yishuo Wang -Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> -Yui -Yusuf Kağan Hanoğlu -Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> -ZHAOKAI WANG -Zane Shannon -Zay <95888118+isaiahbjork@users.noreply.github.com> -Zenix -Zhang Peiyuan -Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> -ZhouYuChen -Ziad Ben Hadj-Alouane -Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> -Zsapi -a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> -adel boussaken -afrideva <95653597+afrideva@users.noreply.github.com> -agray3 -akawrykow <142945436+akawrykow@users.noreply.github.com> -alexpinel <93524949+alexpinel@users.noreply.github.com> -alonfaraj -alwqx -amd-lalithnc -andrijdavid -anon998 <131767832+anon998@users.noreply.github.com> -anzz1 -apaz -apcameron <37645737+apcameron@users.noreply.github.com> -arch-btw <57669023+arch-btw@users.noreply.github.com> -arcrank -arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> -at8u <129688334+at8u@users.noreply.github.com> -automaticcat -bandoti <141645996+bandoti@users.noreply.github.com> -beiller -bhubbb <79117352+bhubbb@users.noreply.github.com> -bmwl -bobqianic <129547291+bobqianic@users.noreply.github.com> -bryanSwk <93190252+bryanSwk@users.noreply.github.com> -bsilvereagle -bssrdf -byte-6174 <88070277+byte-6174@users.noreply.github.com> -cebtenzzre -chaihahaha -chiranko <96988916+chiranko@users.noreply.github.com> -clibdev <52199778+clibdev@users.noreply.github.com> -clyang -cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> -coezbek -comex -compilade <113953597+compilade@users.noreply.github.com> -compilade -cpumaxx <163466046+cpumaxx@users.noreply.github.com> -crasm -crasm -daboe01 -david raistrick -ddh0 -ddpasa <112642920+ddpasa@users.noreply.github.com> -deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> -divinity76 -dm4 -dotpy314 <33351922+dotpy314@users.noreply.github.com> -drbh -ds5t5 <145942675+ds5t5@users.noreply.github.com> -dylan -eastriver -ebraminio -eiery <19350831+eiery@users.noreply.github.com> -eric8607242 +Kawrakow +saood06 +Nexes the Elder <124105151+Nexesenex@users.noreply.github.com> fairydreaming <166155368+fairydreaming@users.noreply.github.com> -fraxy-v <65565042+fraxy-v@users.noreply.github.com> -github-actions[bot] -gliptic -goerch -grahameth <96447521+grahameth@users.noreply.github.com> -gwjr <502526+gwjr@users.noreply.github.com> -h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> -hankcs -hoangmit -hongbo.mo <352280764@qq.com> -hopkins385 <98618192+hopkins385@users.noreply.github.com> -howlger -howlger -hutli <6594598+hutli@users.noreply.github.com> -hutli -hutli -hxer7963 -hydai -iSma -iacore <74560659+iacore@users.noreply.github.com> -igarnier -intelmatt <61025942+intelmatt@users.noreply.github.com> -iohub -jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> -jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> -jameswu2014 <545426914@qq.com> -jiez <373447296@qq.com> -jneem -joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> -johnson442 <56517414+johnson442@users.noreply.github.com> -jojorne -jon-chuang <9093549+jon-chuang@users.noreply.github.com> -jp-x-g -jukofyork <69222624+jukofyork@users.noreply.github.com> -junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> -jwj7140 <32943891+jwj7140@users.noreply.github.com> -k.h.lai -kaizau -kalomaze <66376113+kalomaze@users.noreply.github.com> -kang -katsu560 <118887472+katsu560@users.noreply.github.com> -kchro3 <62481661+kchro3@users.noreply.github.com> -khimaros -kiltyj -klosax <131523366+klosax@users.noreply.github.com> -kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> -kunnis -kuronekosaiko -kuvaus <22169537+kuvaus@users.noreply.github.com> -kwin1412 <42286931+kwin1412@users.noreply.github.com> -l3utterfly -ldwang -le.chang -leejet -limitedAtonement -liuwei-git <14815172+liuwei-git@users.noreply.github.com> -lon <114724657+longregen@users.noreply.github.com> -loonerin <132926317+loonerin@users.noreply.github.com> -luoyu-intel -m3ndax -maddes8cht <55592906+maddes8cht@users.noreply.github.com> -makomk -manikbhandari -maor-ps <154728172+maor-ps@users.noreply.github.com> -mdrokz -mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> -minarchist -mj-shifu <77107165+mj-shifu@users.noreply.github.com> -mmyjona -momonga <115213907+mmnga@users.noreply.github.com> -moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> -mzcu -nanahi <130121847+na-na-hi@users.noreply.github.com> -ngc92 <7938269+ngc92@users.noreply.github.com> -nhamanasu <45545786+nhamanasu@users.noreply.github.com> -niansa/tuxifan -niansa/tuxifan -nickp27 -ningshanwutuobang -nold -nopperl <54780682+nopperl@users.noreply.github.com> -nusu-github <29514220+nusu-github@users.noreply.github.com> -olexiyb -omahs <73983677+omahs@users.noreply.github.com> -oobabooga <112222186+oobabooga@users.noreply.github.com> -opparco -ostix360 <55257054+ostix360@users.noreply.github.com> -pengxin99 -perserk -pmysl -postmasters -pudepiedj -qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> -qouoq -qunash -rabidcopy -rankaiyx -rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> -rhuddleston -rimoliga <53384203+rimoliga@users.noreply.github.com> -runfuture -sandyiscool -sasha0552 -semidark -sharpHL <132747147+sharpHL@users.noreply.github.com> -shibe2 -singularity <12184989+singularity-s0@users.noreply.github.com> -sjinzh -sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> -slaren <2141330+slaren@users.noreply.github.com> -slaren -snadampal <87143774+snadampal@users.noreply.github.com> -staviq -stduhpf -strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> -swittk -takov751 <40316768+takov751@users.noreply.github.com> -tarcey -texmex76 <40733439+texmex76@users.noreply.github.com> -thement <40525767+thement@users.noreply.github.com> -tjohnman -tslmy -ubik2 -uint256_t -uint256_t -unbounded -valiray <133289098+valiray@users.noreply.github.com> -vik -viric -vodkaslime <646329483@qq.com> -vvhg1 <94630311+vvhg1@users.noreply.github.com> -vxiiduu <73044267+vxiiduu@users.noreply.github.com> -wbpxre150 <100937007+wbpxre150@users.noreply.github.com> -whoreson <139810751+whoreson@users.noreply.github.com> -woachk <24752637+woachk@users.noreply.github.com> -wonjun Jang -woodx <124784234+woodx9@users.noreply.github.com> -wzy <32936898+Freed-Wu@users.noreply.github.com> -xaedes -xaedes -xloem <0xloem@gmail.com> -yangli2 -yuiseki -zakkor -zhangkaihuo -zhouwg <6889919+zhouwg@users.noreply.github.com> -zhouwg -zrm -Ștefan-Gabriel Muscalu -源文雨 <41315874+fumiama@users.noreply.github.com> -Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> +Stanisław Szymczyk +ubergarm diff --git a/CMakeLists.txt b/CMakeLists.txt index cb59656e..3e9c3cc0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,14 @@ if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) endif() +# force MSVC compiler charset to utf-8 +if (MSVC) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/bigobj>") + add_compile_options("$<$:/bigobj>") +endif() + # # option list # @@ -115,6 +123,29 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + # # build the library # @@ -123,6 +154,11 @@ if (NOT TARGET ggml) add_subdirectory(ggml) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() + +# +# build the library +# + add_subdirectory(src) # diff --git a/LICENSE b/LICENSE index 03f0ee9c..be8d3b11 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,8 @@ MIT License -Copyright (c) 2023-2024 The ggml authors -Copyright (c) 2024 Iwan Kawrakow +Copyright (c) 2023-2024 The ggml authors (https://github.com/ggml-org/ggml/blob/master/AUTHORS) +Copyright (c) 2023-2024 The llama.cpp authors (https://github.com/ggml-org/llama.cpp/blob/master/AUTHORS) +Copyright (c) 2024-2025 The ik_llama.cpp authors (https://github.com/ikawrakow/ik_llama.cpp/blob/main/AUTHORS) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 782d0150..c1381cad 100644 --- a/README.md +++ b/README.md @@ -1,269 +1,71 @@ -# llama.cpp clone with better CPU performance +# ik_llama.cpp: llama.cpp fork with better CPU performance [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) ## TL;DR -This repository is a clone of [llama.cpp](https://github.com/ggerganov/llama.cpp) with the following improvements -* Better implementation of CPU matrix multiplications (`AVX2` and `ARM_NEON`) for `fp16/fp32` and all k-, i-, and legacy `llama.cpp` quants, that leads to a significant improvement in prompt processing (PP) speed, typically in the range of 2X, but up to 4X for some quantization types. Token generation (TG) also benefits, but to a lesser extent due to TG being memory bound -* Faster CPU inference for MoE models with similar performance gains -* Implementation of the [Bitnet b1.58](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) model for the CPU (`AVX2` and `ARM_NEON`) and GPU (`CUDA` and `Metal`). This implementation is much faster than the unmerged `llama.cpp` [PR-8151](https://github.com/ggerganov/llama.cpp/pull/8151) +This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) with better CPU and hybrid GPU/CPU performance, new SOTA quantization types, first-class Bitnet support, better DeepSeek performance via MLA, FlashMLA, fused MoE operations and tensor overrides for hybrid GPU/CPU inference, row-interleaved quant packing, etc. -If you are not already familiar with [llama.cpp](https://github.com/ggerganov/llama.cpp), it is better to start there. For those familiar with `llama.cpp`, everything here works the same as in `llama.cpp` (or at least the way `llama.cpp` worked when I last synced on Aug 12 2024). +>[!IMPORTANT] +>The new GGUFs for DeepSeek-V3/R1/Lite do not work in this repository. This is due to the backwards incompatible change in mainline `llama.cpp` that [added MLA support](https://github.com/ggml-org/llama.cpp/pull/12801) +>2.5 months after MLA was available here, and worked with the original DeepSeek GGUFs. Please use the original GGUF or, if you don't have one, convert the HF safetensors using the Python conversion script in this repository. +> +>**Update** There is now [PR 394](https://github.com/ikawrakow/ik_llama.cpp/pull/394) addressing the issue. Would appreciate testing with DeepSeek-V3/R1. -Note that I have published some, but not all, of the code in this repository in a series of [llamafile](https://github.com/Mozilla-Ocho/llamafile) PRs ([394](https://github.com/Mozilla-Ocho/llamafile/pull/394), [405](https://github.com/Mozilla-Ocho/llamafile/pull/405), [428](https://github.com/Mozilla-Ocho/llamafile/pull/428), [435](https://github.com/Mozilla-Ocho/llamafile/pull/435), [453](https://github.com/Mozilla-Ocho/llamafile/pull/453), and [464](https://github.com/Mozilla-Ocho/llamafile/pull/464)) +## Latest News -The implementation of matrix-matrix and matrix-vector multiplications is in a single C++ source file (`iqk_mul_mat.cpp`) with just two interface functions `iqk_mul_mat` (`fp16/fp32` and quantized matrix multiplications) and `iqk_mul_mat_moe` (as `iqk_mul_mat` but meant to be used for the FFN part of a MoE model). Under the hood `iqk_mul_mat_moe` uses the same implementation as `iqk_mul_mat`, with the only difference being where results are stored in memory. Bitnet quantization related stuff is in `iqk-quantize.cpp`. +* May 9 2025: Support for LlaMA-3-Nemotron models added, see [PR 377](https://github.com/ikawrakow/ik_llama.cpp/pull/377) +* May 7 2025: 🚀 Faster TG for DeepSeek models with GPU or hybrid GPU/CPU inference. See [PR 386](https://github.com/ikawrakow/ik_llama.cpp/pull/386) for details. Caveat: Ampere or newer Nvidia GPU required +* May 4 2025: 🚀 Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see [PR #370](https://github.com/ikawrakow/ik_llama.cpp/pull/370) +* April 29 2025: Qwen3 support added +* April 26 2025: GLM-4 support added +* April 26 2025: Command-A support added +* April 22 2025: Support for the latest Microsoft Bitnet model added +* April 21 2025: ik_llama.cpp builds and runs successfully on Android (using termux) +* April 17 2025: 🚀 Better CPU Flash Attention token generation performance +* April 13 2025: `IQ1_M` quantization improvements +* April 10 2025: LLaMA-4 support added +* April 7 2025: `IQ2_XS` quantization improvements +* April 3 2025: 🚀 Much faster MoE implementation on Metal +* April 1 2025: Quantization improvements for `Q2_K, Q4_K, Q5_K, Q4_1, Q5_1` +* March 28 2025: Quantization imrovements for `Q4_0, Q5_0, Q6_0, Q3_K, Q6_K, IQ4_XS, IQ4_NL` +* March 25 2025: 🚀 Better MoE performance on CUDA +* March 23 2025: 🚀 Better batched processing speed for DeepSeek models +* March 22 2025: Gemma3 support added +* March 21 2025: 🚀 FlashMLA-3: fastest CPU-only inference for DeepSeek models +* March 18 2025: Reduce compute buffer size +* March 17 2025: 🚀 FlashMLA-2 performance improvements +* March 12 2025: Allow `Q8_0` KV cache with FlashMLA-2 on CUDA +* March 10 2025: 🚀 Better TG performance for MoE models on CUDA +* March 9 2025: 🚀 FlashMLA on CUDA +* March 8 2025: 🚀 Faster FlashMLA CPU implementation +* March 7 2025: Custom quantization mixes using regular expressions +* March 5 2025: 🚀 FlashMLA on CUDA +* March 3 2025: 🚀 Introducing FlashMLA - MLA with Flash Attention +* March 1 2025: Smart Expert Reduction for faster DeepSeek inference +* Feb 27 2025: MLA without transposed cache +* Feb 25 2025: Tensor overrides for better control where model weights are stored (GPU or CPU) +* Feb 23 2025: 🚀 Fused FFN ops for faster MoE inference +* Feb 23 2025: `sweep-bench` - better performance benchmarking +* Feb 20 2025: 🚀 Fast GEMM/GEMV for `IQ1_S` +* Feb 19 2025: `Q8_KV` - new type for 8-bit KV-cache quantization +* Feb 13 2025: Allow `Q8_0` quantized cache with MLA +* Feb 11 2025: 🚀 Flash Attention support for DeepSeek models +* Feb 9 2025: 🚀 MLA for DeepSeek models +* Jan 23 2025: DeepSeek-V3 support added -## Why? +## Resources -Mostly out of curiosity: -* Justine Tunney's `tinyBLAS`, which she contributed to `llama.cpp` in [PR 6414](https://github.com/ggerganov/llama.cpp/pull/6414), only works for `Q4_0`, `Q8_0` and `fp16/bf16` models. In the surrounding discussion about possibly extending `tinyBLAS` to k- and i-quants, she felt that k-quants are [not amenable to block-tiling](https://github.com/ggerganov/llama.cpp/pull/6840#issuecomment-2072995387), which is required to improve performance. This statement piqued my curiosity, so here we are. -* Bitnet-1.58b has been one of the [most discussed topics](https://github.com/ggerganov/llama.cpp/issues/5761#issuecomment-2198380366) in the `llama.cpp` project, so eventually I decided to see how efficiently one can implement a ternary model +There is no single point of reference describing all new `ik_llama.cpp` features. Pull requests often contain detailed information, so browsing the PRs is often the best way to learn about new features and how to use them. In addition +* [The Wiki page](https://github.com/ikawrakow/ik_llama.cpp/wiki) has performance comparisons to mainline `llama.cpp` +* [This guide](https://github.com/ikawrakow/ik_llama.cpp/discussions/258) is a good place to start if you came here because of DeepSeek models +* [This discussion](https://github.com/ikawrakow/ik_llama.cpp/discussions/266) is about running DeepSeek-V3/R1 on a 16 x 3090 setup +* [This discussion](https://github.com/ikawrakow/ik_llama.cpp/discussions/8) describes the new quantization types available in `ik_llama.cpp` -Curiosity aside, improved CPU performance may be (or may become) important in practice. According to The Register, 70% of AI inference [is done on the CPU of mobile phones](https://www.theregister.com/2024/05/30/arm_cortex_x925_ai_cores/?td=rt-3a), at least in the Android world (but I haven't come around to actually comparing performance on a phone). With ever increasing number of LLM model parameters, and with Meta's 400B model just released, the CPU may become the only viable option for people not willing (or not able to) rent/buy uber expensive GPU instances capable of running such models. Granted, one would need a pretty beefy computer to run a 400B model, and inference speed will be sluggish, but at least one will not need to spend the equivalent of a luxury apartment in the downtown of the city where I live to buy the GPU system capable of running the model. +## Contributing -## Performance comparison to llama.cpp +Contributions in form of pull requests, issue submissions (bug reports, feature requests), or general discussions, are welcome. -The results in the following tables are obtained with these parameters: -* Model is LLaMA-v3-8B for `AVX2` and LLaMA-v2-7B for `ARM_NEON` -* The `AVX2` CPU is a 16-core Ryzen-7950X -* The `ARM_NEON` CPU is M2-Max -* `tinyBLAS` is enabled in `llama.cpp` -* `llama.cpp` results are for `build: 081fe431 (3441)`, which was the current `llama.cpp` master branch when I pulled on July 23 2024. -* The projects are built without `CUDA` support, no `BLAS`, and Accelerate framework disabled - -### Prompt processing - -Here I set the number of threads to be equal to the number of (performance) cores of the CPU, so 16 threads for the Ryzen-7950X and 8 threads for the M2-Max. The following table summarizes the results. To not make the table too long, I have listed only quantized models containing predominantly one quantization type (i.e., excluded the `QX_K - Medium/Large` variants, which are typically a mix of `QX_K` and `Q(X+1)_K`, as well as `IQ2_S` and `IQ3_XS`). - -The command line to generate the benchmark data is -``` -./bin/llama-bench -m $model -p 512 -n 0 -t $num_threads -ngl 0 -``` - -| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ----------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | -| 8B F16 | 14.96 GiB | AVX2 | 16 | 112.37 ± 0.40 | 131.27 ± 0.38 | 1.168 | -| 7B F16 | 12.55 GiB | NEON | 8 | 90.28 ± 1.25 | 95.34 ± 0.15 | 1.056 | -| 8B Q8_0 | 7.95 GiB | AVX2 | 16 | 118.07 ± 0.53 | 134.00 ± 0.47 | 1.135 | -| 7B Q8_0 | 6.67 GiB | NEON | 8 | 77.25 ± 1.81 | 94.14 ± 1.15 | 1.219 | -| 8B Q4_0 | 4.35 GiB | AVX2 | 16 | 104.46 ± 0.33 | 130.20 ± 0.29 | 1.246 | -| 7B Q4_0 | 3.57 GiB | NEON | 8 | 65.46 ± 0.79 | 76.22 ± 0.71 | 1.164 | -| 8B Q4_1 | 4.77 GiB | AVX2 | 16 | 57.83 ± 0.24 | 160.69 ± 0.49 | 2.779 | -| 7B Q4_1 | 3.95 GiB | NEON | 8 | 37.40 ± 0.50 | 65.83 ± 0.98 | 1.760 | -| 8B Q5_0 | 5.22 GiB | AVX2 | 16 | 53.50 ± 0.35 | 122.62 ± 0.48 | 2.292 | -| 7B Q5_0 | 4.34 GiB | NEON | 8 | 29.31 ± 0.51 | 67.51 ± 1.17 | 2.303 | -| 8B Q5_1 | 5.64 GiB | AVX2 | 16 | 50.85 ± 0.36 | 147.15 ± 0.47 | 2.894 | -| 7B Q5_1 | 4.72 GiB | NEON | 8 | 26.02 ± 0.37 | 58.49 ± 0.85 | 2.248 | -| 8B Q2_K_S | 2.78 GiB | AVX2 | 16 | 110.11 ± 0.28 | 192.47 ± 1.35 | 1.748 | -| 7B Q2_K_S | 2.16 GiB | NEON | 8 | 35.44 ± 0.06 | 77.93 ± 1.64 | 2.199 | -| 8B Q3_K_S | 3.41 GiB | AVX2 | 16 | 77.42 ± 0.36 | 181.64 ± 0.44 | 2.346 | -| 7B Q3_K_S | 2.75 GiB | NEON | 8 | 26.79 ± 0.03 | 59.38 ± 1.08 | 2.216 | -| 8B Q4_K_S | 4.36 GiB | AVX2 | 16 | 98.92 ± 0.34 | 185.35 ± 0.39 | 1.874 | -| 7B Q4_K_S | 3.59 GiB | NEON | 8 | 46.55 ± 0.67 | 76.31 ± 0.38 | 1.639 | -| 8B Q5_K_S | 5.21 GiB | AVX2 | 16 | 69.44 ± 0.31 | 179.62 ± 0.69 | 2.587 | -| 7B Q5_K_S | 4.33 GiB | NEON | 8 | 30.18 ± 0.23 | 65.34 ± 0.79 | 2.165 | -| 8B Q6_K | 6.14 GiB | AVX2 | 16 | 74.89 ± 0.26 | 181.86 ± 0.55 | 2.428 | -| 7B Q6_K | 5.15 GiB | NEON | 8 | 28.12 ± 1.24 | 60.75 ± 1.15 | 2.160 | -| 8B IQ2_XXS | 2.23 GiB | AVX2 | 16 | 42.57 ± 0.16 | 126.63 ± 0.55 | 2.975 | -| 7B IQ2_XXS | 1.73 GiB | NEON | 8 | 20.87 ± 0.20 | 64.29 ± 1.12 | 3.080 | -| 8B IQ2_XS | 2.42 GiB | AVX2 | 16 | 46.45 ± 0.27 | 125.46 ± 0.43 | 2.701 | -| 7B IQ2_XS | 1.89 GiB | NEON | 8 | 22.77 ± 0.21 | 51.15 ± 0.24 | 2.246 | -| 8B IQ2_M | 2.74 GiB | AVX2 | 16 | 40.76 ± 0.18 | 113.07 ± 0.48 | 2.774 | -| 7B IQ2_M | 2.20 GiB | NEON | 8 | 14.95 ± 0.26 | 44.87 ± 0.50 | 3.001 | -| 8B IQ3_XXS | 3.04 GiB | AVX2 | 16 | 31.95 ± 0.20 | 109.86 ± 0.45 | 3.438 | -| 7B IQ3_XXS | 2.41 GiB | NEON | 8 | 14.40 ± 0.10 | 53.58 ± 0.85 | 3.721 | -| 8B IQ3_S | 3.42 GiB | AVX2 | 16 | 28.04 ± 0.08 | 96.28 ± 0.45 | 3.434 | -| 7B IQ3_S | 2.75 GiB | NEON | 8 | 12.08 ± 0.30 | 49.72 ± 0.06 | 4.116 | -| 8B IQ4_XS | 4.13 GiB | AVX2 | 16 | 68.98 ± 0.31 | 180.34 ± 0.55 | 2.614 | -| 7B IQ4_XS | 3.37 GiB | NEON | 8 | 40.67 ± 1.97 | 75.11 ± 1.97 | 1.847 | -| 8B IQ4_NL | 4.35 GiB | AVX2 | 16 | 59.94 ± 0.21 | 129.06 ± 0.43 | 2.153 | -| 7B IQ4_NL | 3.56 GiB | NEON | 8 | 34.36 ± 0.81 | 76.02 ± 1.36 | 2.212 | - -We see that `llama.cpp` achieves respectable performance for `fp16`, `Q8_0`, and `Q4_0`, being only up to 25% slower than this implementation. This is thanks to the use of Justine Tunney's `tinyBLAS`, which is utilized for these quantization types. For all other quants we observe performance gains in the `1.75X - 4X` range, which is not a small feat considering that the `ggml` matrix multiplication functions has been rewritten several times since `llama.cpp` was first published. Performance gains are larger for i-quants due to the higher quant unpacking cost (see discussion in "To tile or not to tile") - -### Token generation - -On the Ryzen-7950X TG is memory bound, and for many quantization types peak performance is achieved at just 4 threads. Hence, only results for 2 and 4 threads are shown for `AVX2`. The M2-Max has a much more capable memory subsystem and as a result performance keep increasing up to 8 threads. Thus, results are given for up to 8 threads for `ARM_NEON`. - -The command line to generate the data was -``` -./bin/llama-bench -m $model -p 0 -n 128 -t $num_threads -ngl 0 -``` - -| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ---------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | -| 8B F16 | 14.96 GiB | AVX2 | 1 | 2.20 ± 0.00 | 2.25 ± 0.00 | 1.023 | -| | | | 2 | 3.63 ± 0.00 | 3.68 ± 0.00 | 1.014 | -| | | | 4 | 4.20 ± 0.00 | 4.20 ± 0.00 | 1.000 | -| 7B F16 | 12.55 GiB | NEON | 2 | 6.94 ± 0.27 | 7.40 ± 0.01 | 1.066 | -| | | | 4 | 8.73 ± 0.01 | 8.83 ± 0.01 | 1.011 | -| | | | 6 | 9.05 ± 0.02 | 9.05 ± 0.01 | 1.000 | -| 8B Q8_0 | 7.95 GiB | AVX2 | 2 | 5.03 ± 0.00 | 7.87 ± 0.00 | 1.565 | -| | | | 4 | 7.40 ± 0.00 | 7.82 ± 0.00 | 1.057 | -| 7B Q8_0 | 6.67 GiB | NEON | 2 | 8.29 ± 0.44 | 12.07 ± 0.10 | 1.456 | -| | | | 4 | 13.53 ± 0.03 | 15.77 ± 0.08 | 1.166 | -| | | | 8 | 16.24 ± 0.10 | 16.94 ± 0.04 | 1.043 | -| 8B Q4_0 | 4.35 GiB | AVX2 | 2 | 6.36 ± 0.00 | 10.28 ± 0.00 | 1.616 | -| | | | 4 | 10.97 ± 0.06 | 13.55 ± 0.07 | 1.235 | -| 7B Q4_0 | 3.57 GiB | NEON | 2 | 9.77 ± 0.02 | 13.69 ± 0.03 | 1.401 | -| | | | 4 | 17.82 ± 0.06 | 23.98 ± 0.11 | 1.346 | -| | | | 8 | 26.63 ± 0.41 | 29.86 ± 0.04 | 1.121 | -| 8B Q4_1 | 4.77 GiB | AVX2 | 2 | 5.11 ± 0.00 | 11.45 ± 0.00 | 2.241 | -| | | | 4 | 9.08 ± 0.02 | 12.58 ± 0.00 | 1.385 | -| 7B Q4_1 | 3.95 GiB | NEON | 2 | 9.11 ± 0.06 | 14.62 ± 0.04 | 1.605 | -| | | | 4 | 17.04 ± 0.09 | 24.08 ± 0.28 | 1.413 | -| | | | 8 | 25.26 ± 0.24 | 27.23 ± 0.14 | 1.078 | -| 8B Q5_0 | 5.22 GiB | AVX2 | 2 | 5.31 ± 0.01 | 8.30 ± 0.01 | 1.563 | -| | | | 4 | 9.40 ± 0.01 | 11.47 ± 0.00 | 1.220 | -| 7B Q5_0 | 4.34 GiB | NEON | 2 | 7.26 ± 0.06 | 7.52 ± 0.00 | 1.036 | -| | | | 4 | 13.63 ± 0.18 | 14.16 ± 0.10 | 1.039 | -| | | | 8 | 22.55 ± 0.35 | 24.34 ± 0.22 | 1.079 | -| 8B Q5_1 | 5.64 GiB | AVX2 | 2 | 4.52 ± 0.00 | 8.86 ± 0.00 | 1.960 | -| | | | 4 | 7.72 ± 0.05 | 10.68 ± 0.03 | 1.383 | -| 7B Q5_1 | 4.72 GiB | NEON | 2 | 6.51 ± 0.01 | 6.42 ± 0.03 | 0.986 | -| | | | 4 | 12.26 ± 0.18 | 12.21 ± 0.14 | 0.996 | -| | | | 8 | 20.33 ± 0.52 | 21.85 ± 0.22 | 1.075 | -| 8B Q2_K_S | 2.78 GiB | AVX2 | 2 | 11.30 ± 0.00 | 13.06 ± 0.01 | 1.156 | -| | | | 4 | 18.70 ± 0.00 | 19.04 ± 0.65 | 1.014 | -| 7B Q2_K_S | 2.16 GiB | NEON | 2 | 8.42 ± 0.05 | 11.97 ± 0.10 | 1.422 | -| | | | 4 | 15.74 ± 0.01 | 22.09 ± 0.08 | 1.403 | -| | | | 8 | 27.35 ± 0.05 | 38.32 ± 0.05 | 1.401 | -| 8B Q3_K_S | 3.41 GiB | AVX2 | 2 | 8.58 ± 0.00 | 10.82 ± 0.00 | 1.261 | -| | | | 4 | 15.26 ± 0.01 | 16.25 ± 0.01 | 1.065 | -| 7B Q3_K_S | 2.75 GiB | NEON | 2 | 6.40 ± 0.02 | 9.12 ± 0.09 | 1.425 | -| | | | 4 | 12.17 ± 0.00 | 17.11 ± 0.03 | 1.406 | -| | | | 8 | 22.04 ± 0.08 | 31.39 ± 0.31 | 1.424 | -| 8B Q4_K_S | 4.36 GiB | AVX2 | 2 | 9.61 ± 0.00 | 10.72 ± 0.01 | 1.116 | -| | | | 4 | 13.24 ± 0.31 | 13.28 ± 0.01 | 1.003 | -| 7B Q4_K_S | 3.59 GiB | NEON | 2 | 11.15 ± 0.05 | 12.93 ± 0.09 | 1.160 | -| | | | 4 | 20.24 ± 0.16 | 23.49 ± 0.29 | 1.161 | -| | | | 8 | 25.76 ± 0.07 | 28.31 ± 0.22 | 1.099 | -| 8B Q5_K_S | 5.21 GiB | AVX2 | 2 | 7.45 ± 0.00 | 9.73 ± 0.00 | 1.306 | -| | | | 4 | 11.05 ± 0.33 | 11.43 ± 0.02 | 1.034 | -| 7B Q5_K_S | 4.33 GiB | NEON | 2 | 7.20 ± 0.04 | 8.81 ± 0.04 | 1.224 | -| | | | 4 | 13.62 ± 0.15 | 16.81 ± 0.16 | 1.234 | -| | | | 8 | 20.56 ± 0.19 | 23.96 ± 0.14 | 1.165 | -| 8B Q6_K | 6.14 GiB | AVX2 | 2 | 7.53 ± 0.00 | 9.42 ± 0.00 | 1.251 | -| | | | 4 | 9.74 ± 0.00 | 9.97 ± 0.01 | 1.024 | -| 7B Q6_K | 5.15 GiB | NEON | 2 | 6.85 ± 0.04 | 8.30 ± 0.06 | 1.212 | -| | | | 4 | 13.03 ± 0.05 | 15.47 ± 0.17 | 1.187 | -| | | | 8 | 18.52 ± 0.07 | 20.67 ± 0.08 | 1.116 | -| 8B IQ2_XXS | 2.23 GiB | AVX2 | 2 | 5.33 ± 0.01 | 6.40 ± 0.00 | 1.201 | -| | | | 4 | 10.06 ± 0.03 | 11.76 ± 0.03 | 1.169 | -| 7B IQ2_XXS | 1.73 GiB | NEON | 2 | 5.07 ± 0.04 | 5.22 ± 0.05 | 1.030 | -| | | | 4 | 9.63 ± 0.00 | 9.91 ± 0.07 | 1.029 | -| | | | 8 | 17.40 ± 0.50 | 18.65 ± 0.22 | 1.072 | -| 8B IQ2_XS | 2.42 GiB | AVX2 | 2 | 5.83 ± 0.00 | 6.55 ± 0.00 | 1.123 | -| | | | 4 | 10.88 ± 0.09 | 12.07 ± 0.07 | 1.109 | -| 7B IQ2_XS | 1.89 GiB | NEON | 2 | 5.52 ± 0.01 | 5.60 ± 0.00 | 1.014 | -| | | | 4 | 10.50 ± 0.01 | 11.15 ± 0.00 | 1.062 | -| | | | 8 | 18.19 ± 1.30 | 20.94 ± 0.19 | 1.151 | -| 8B IQ2_M | 2.74 GiB | AVX2 | 2 | 5.12 ± 0.01 | 5.17 ± 0.00 | 1.010 | -| | | | 4 | 9.60 ± 0.28 | 9.68 ± 0.16 | 1.008 | -| 7B IQ2_M | 2.20 GiB | NEON | 2 | 3.73 ± 0.02 | 4.53 ± 0.00 | 1.214 | -| | | | 4 | 7.14 ± 0.05 | 8.70 ± 0.06 | 1.218 | -| | | | 8 | 11.99 ± 0.48 | 16.41 ± 0.05 | 1.369 | -| 8B IQ3_XXS | 3.04 GiB | AVX2 | 2 | 4.06 ± 0.01 | 5.00 ± 0.00 | 1.232 | -| | | | 4 | 7.75 ± 0.02 | 9.13 ± 0.45 | 1.178 | -| 7B IQ3_XXS | 2.41 GiB | NEON | 2 | 3.53 ± 0.00 | 3.82 ± 0.00 | 1.082 | -| | | | 4 | 6.74 ± 0.04 | 7.42 ± 0.07 | 1.103 | -| | | | 8 | 11.96 ± 0.40 | 13.19 ± 0.29 | 1.103 | -| 8B IQ3_S | 3.42 GiB | AVX2 | 2 | 3.62 ± 0.00 | 4.06 ± 0.00 | 1.122 | -| | | | 4 | 6.80 ± 0.01 | 7.62 ± 0.10 | 1.121 | -| 7B IQ3_S | 2.75 GiB | NEON | 2 | 2.96 ± 0.01 | 3.21 ± 0.03 | 1.084 | -| | | | 4 | 5.68 ± 0.01 | 6.25 ± 0.05 | 1.100 | -| | | | 8 | 10.32 ± 0.25 | 11.11 ± 0.37 | 1.077 | -| 8B IQ4_XS | 4.13 GiB | AVX2 | 2 | 8.08 ± 0.00 | 11.35 ± 0.00 | 1.405 | -| | | | 4 | 13.36 ± 0.72 | 14.32 ± 0.24 | 1.072 | -| 7B IQ4_XS | 3.37 GiB | NEON | 2 | 9.87 ± 0.03 | 12.06 ± 0.00 | 1.222 | -| | | | 4 | 17.78 ± 0.23 | 22.06 ± 0.28 | 1.241 | -| | | | 8 | 27.62 ± 0.09 | 29.70 ± 0.39 | 1.075 | -| 8B IQ4_NL | 4.35 GiB | AVX2 | 2 | 5.52 ± 0.00 | 10.26 ± 0.00 | 1.859 | -| | | | 4 | 10.78 ± 0.01 | 13.69 ± 0.08 | 1.270 | -| 7B IQ4_NL | 3.56 GiB | NEON | 2 | 8.32 ± 0.01 | 13.54 ± 0.01 | 1.627 | -| | | | 4 | 15.89 ± 0.00 | 24.28 ± 0.29 | 1.528 | -| | | | 8 | 26.56 ± 0.36 | 29.87 ± 0.08 | 1.125 | - -Here gains are generally lower compared to PP due to TG performance being limited by memory bandwidth. Nevertheless, for some quants/architectures/threads the speedup is quite remarkable (e.g., almost a factor of 2 for `Q5_1` on `AVX2` with 2 threads). - -## MoE models - -There is [PR-6840](https://github.com/ggerganov/llama.cpp/pull/6840) from Justine Tunney in `llama.cpp`, but it has not been merged since April 23, so I'll compare performance to the master branch for Mixtral-8x7B. As Mixtral8x7B quantization is quite a lengthy process, the following table shows data only for `Q4_K_S` (a commonly used k-quant, 4 bit), `Q5_0` (a legacy quant, 5 bit), and `IQ4_XXS` (a 3-bit i-quant) - -| model | size | backend | threads | test | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ------------ | ---------: | ---------- | ------: | -------: | ---------------: | ---------------: | ------: | -| 8x7B Q4_K_S | 48.75 GiB | AVX2 | 16 | pp512 | 54.92 ± 0.23 | 102.94 ± 0.37 | 1.874 | -| | | NEON | 8 | pp512 | 23.54 ± 1.56 | 38.32 ± 0.54 | 1.628 | -| | | AVX2 | 4 | tg128 | 7.80 ± 0.07 | 7.83 ± 0.09 | 1.004 | -| | | NEON | 8 | tg128 | 14.95 ± 0.25 | 15.28 ± 0.24 | 2.022 | -| 8x7B IQ3_XXS | 33.07 GiB | AVX2 | 16 | pp512 | 17.58 ± 0.04 | 68.45 ± 0.22 | 3.894 | -| | | NEON | 8 | pp512 | 7.75 ± 0.04 | 34.67 ± 0.40 | 4.474 | -| | | AVX2 | 4 | tg128 | 4.60 ± 0.01 | 5.45 ± 0.09 | 1.185 | -| | | AVX2 | 8 | tg128 | 8.04 ± 0.65 | 9.83 ± 0.06 | 1.223 | -| | | AVX2 | 16 | tg128 | 10.42 ± 0.01 | 10.57 ± 0.01 | 1.014 | -| | | NEON | 8 | tg128 | 6.19 ± 1.16 | 7.27 ± 0.14 | 1.174 | -| 8x7B Q5_0 | 59.11 GiB | AVX2 | 16 | pp512 | 29.06 ± 0.43 | 62.67 ± 0.32 | 2.157 | -| | | NEON | 8 | pp512 | 15.17 ± 0.51 | 27.36 ± 1.03 | 1.804 | -| | | AVX2 | 4 | tg128 | 5.44 ± 0.10 | 6.81 ± 0.06 | 1.252 | -| | | NEON | 8 | tg128 | 12.03 ± 0.77 | 12.41 ± 1.27 | 1.032 | - - -## Bitnet-1.58B - -Two implementations are provided -* `IQ1_BN` - uses 1.625 bits-per-weight (bpw) -* `IQ2_BN` - uses 2.0 bpw - -`IQ2_BN` is faster for PP (CPU and GPU, although the PP performance difference on CUDA is very minor). `IQ1_BN` can arrive at a higher TG performance on the Ryzen-7950X (given enough threads) because of the smaller model size, but it is always slower on the GPU and on the M2-Max CPU. - -There is the unmerged [PR 8151](https://github.com/ggerganov/llama.cpp/pull/8151) in `llama.cpp` that implements Bitnet-1.58B for the CPU (`AVX` and `ARM_NEON`, no GPU implementation). The following table compares performance between this repo and `PR-8151` in `llama.cpp`. The CUDA results were obtained on an RTX-4080, the Metal results on a 30-core M2-Max GPU. - -| model | size | backend | threads | test | t/s (llama.cpp) | t/s (this repo)| Speedup | -| ----------- | ---------: | ---------- | ------: | -----: | ---------------: | -------------: | ------: | -| 3B - IQ1_BN | 729.64 MiB | AVX2 | 16 | pp512 | 120.61 ± 0.48 | 423.19 ± 1.28 | 3.509 | -| | | NEON | 8 | pp512 | 46.64 ± 0.02 | 205.90 ± 0.88 | 4.415 | -| | | CUDA | 8 | pp512 | - | 10660 ± 170 | - | -| | | Metal | 8 | pp512 | - | 698.25 ± 1.91 | - | -| | | AVX2 | 2 | tg128 | 15.79 ± 0.01 | 22.13 ± 0.02 | 1.402 | -| | | AVX2 | 4 | tg128 | 28.64 ± 1.72 | 40.14 ± 0.04 | 1.402 | -| | | AVX2 | 8 | tg128 | 48.91 ± 0.08 | 61.79 ± 0.09 | 1.263 | -| | | AVX2 | 16 | tg128 | 57.73 ± 0.05 | 60.79 ± 0.05 | 1.053 | -| | | NEON | 2 | tg128 | 11.43 ± 0.04 | 16.87 ± 0.02 | 1.476 | -| | | NEON | 4 | tg128 | 21.11 ± 0.05 | 30.66 ± 0.11 | 1.452 | -| | | NEON | 8 | tg128 | 37.36 ± 0.07 | 55.21 ± 0.16 | 1.478 | -| | | CUDA | 8 | tg128 | - | 301.44 ± 0.12 | - | -| | | Metal | 8 | tg128 | - | 76.70 ± 0.07 | - | -| 3B - IQ2_BN | 873.65 MiB | AVX2 | 16 | pp512 | 151.39 ± 0.35 | 540.82 ± 2.48 | 3.572 | -| | | NEON | 8 | pp512 | 46.54 ± 0.03 | 242.05 ± 0.34 | 5.201 | -| | | CUDA | 8 | pp512 | - | 10800 ± 160 | - | -| | | Metal | 8 | pp512 | - | 723.19 ± 0.53 | - | -| | | AVX2 | 2 | tg128 | 18.93 ± 0.02 | 38.34 ± 0.08 | 2.026 | -| | | AVX2 | 4 | tg128 | 34.54 ± 0.06 | 56.29 ± 0.07 | 1.630 | -| | | AVX2 | 8 | tg128 | 52.97 ± 0.07 | 53.44 ± 0.08 | 1.009 | -| | | AVX2 | 16 | tg128 | 51.84 ± 0.25 | 53.46 ± 0.07 | 1.031 | -| | | NEON | 2 | tg128 | 11.40 ± 0.02 | 32.01 ± 0.27 | 2.808 | -| | | NEON | 4 | tg128 | 20.99 ± 0.00 | 56.45 ± 0.11 | 2.689 | -| | | NEON | 8 | tg128 | 37.28 ± 0.08 | 89.77 ± 0.70 | 2.408 | -| | | CUDA | 8 | tg128 | - | 322.10 ± 0.07 | - | -| | | Metal | 8 | tg128 | - | 110.39 ± 0.13 | - | - -We can make the following observations: -* For prompt processing this Bitnet-1.58b implementation is massively better than PR-8151 in `llama.cpp`, with gains between 3.4X and 5.2X! -* We get `PP-512 = 520 t/s` for the 2.0 bpw variant on the Ryzen-7950X, which costs less than $500. Hey, who needs a GPU? -* For low number of threads (2), this implementation is also much faster than PR-8151 for TG, where speed gains are between 1.4X and 2.8X. As we become memory bound on the Ryzen-7950X, the speed advantage goes away there for sufficiently high number of threads. But on the M2-Max this implementation is 1.4X (1.625 bpw) or 2.4X faster even at 8 threads -* Looking at TG on the M2-Max, the GPU looks a bit like wasted silicon (90 vs 110 t/s for TG-128 and the 2.0 bpw variant). If the GPU transistors had been spent to double the M2 number of CPU cores (and all memory bandwidth is given to the CPU), the CPU would be wiping the floor with the GPU. -* I'm of course kidding with the above. Still, it seems there are massive inefficiencies in the `llama.cpp` Metal implementation that start showing up when matrix multiplications become very fast as is the case here. The difference between CPU and GPU prompt processing speed is typically at least a factor of 7 in favor of the GPU on the M2-Max, but it is only around a factor of 3 here. -* It is worth noting that one needs to offload the token embeddings tensor to the GPU, else performance on CUDA/Metal is significantly lower. Bitnet uses the same tensor for token embeddings and for output. Mainline `llama.cpp` currently puts the token embeddings tensor on the CPU, and this results in running the matrix multiplication with the output tensor on the CPU. This most likely affects other models as well (e.g., Gemma), but I haven't yet looked into this. - -To reproduce these results: -* Clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B -* Run `python3 --outtype f16 path_to_bitnet` to convert to GGUF -* Run `./bin/llama-quantize path_to_bitnet/ggml-model-f16.gguf quantized.gguf [iq1_bn | iq2_bn]`. Note: no imatrix is required (and, if you provide one, it is ignored) -* Caveat: only the 3B Bitnet variant works. The smaller Bitnet models contain tensors with number of columns that are not even a multiple of 32, so basically no `llama.cpp` quant will work for these. - -## To tile or not to tile - -The common wisdom for efficient matrix multiplications is to use block tiling, and this is also used here for `fp16/fp32` matrices. But block tiling does not somehow magically reduce the amount of computation that needs to get done. Performance gains are simply due to the better utilization of memory caches. When dealing with quantized matrix multiplications, there is an additional factor that comes into play: the quantized data needs to be unpacked to 8-bit integers before being used in the matrix multiplication multiply-add operations. Depending on quantization type, this unpacking can represent a significant fraction of the overall computation cost. Hence, for best performance, one would want to reuse the unpacked quants as much as possible, thus spending some fraction of the available vector registers to hold the unpacked data. But when using block tiling, one also needs a certain number of vector registers for accumulating results. For instance, on `AVX2` (16 vector registers available), for `fp16/fp32` models best performance is achieved with `2 x 6` tiles (where the `2` refers to rows in the left matrix and is measured in units of the vector register size, so 16/8 floats for `fp16/fp32`, and `6` is for the number of columns in the right matrix). Unpacking quantized data works best when done in blocks of 128 or 256 quants so that, if we wanted to keep unpacked quants for 2 rows, we would need at least 8 vector registers, thus being left with less than 8 registers for result accumulation, so at best `2 x 3` tiles. In practice one needs addition vector registers for various constants that are typically needed for de-quantization, so that, at the end, it becomes better to use `1 x N` "tiles", i.e., a row-wise multiplication where each row in the left matrix is multiplied with `N` columns in the right matrix, thus reusing the unpacked data `N` times. This (i.e., amortizing de-quantization cost) is the main mechanism for seeding up quantized matrix multiplications. Having started with quantized matrices, and having gone from tiles to a row-wise implementation after some experimentation, I did try row-wise multiplication for float matrices first. Performance was not quite as good as for block-tiling, but I did get up to 90-95% of the speed of `tinyBLAS` that way before switching the `fp16/fp32` implementation to `2 x 6` (`AVX2`) or `5 x 5` (`AVX512` and `ARM_NEON`) block-tiles. But even for for `Q8_0 x Q8_0` multiplications, where there is basically no de-quantization cost, row-wise multiplication is faster than tiling (and hence this implemeintation beats `tinyBLAS`, which uses block-tiling also for `Q8_0`). +## License +MIT diff --git a/common/common.cpp b/common/common.cpp index 95e91bc1..f0c618e0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #if defined(_MSC_VER) #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif @@ -265,6 +272,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; } + if (!params.tensor_buft_overrides.empty()) { + params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } return true; } @@ -287,6 +297,60 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } +namespace { +bool parse_buft_overrides(const std::string& value, std::vector& overrides) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) { + //auto * dev = ggml_backend_reg_get_name(i); + auto * buft = ggml_backend_reg_get_default_buffer_type(i); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid buft override argument %s\n", value.c_str()); + return false; + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + if (buft_list.find(buffer_type) == buft_list.end()) { + fprintf(stderr, "Available buffer types:\n"); + for (const auto & it : buft_list) { + fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second)); + } + return false; + } + overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + return true; +} +template +std::vector> string_split_pairs(const std::string & str, char delim) { + std::vector> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + } else { + T2 value; + token_stream >> value; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} +} + #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { @@ -813,6 +877,31 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-mla" || arg == "--mla-use") { + CHECK_ARG + params.mla_attn = std::stoi(argv[i]); + return true; + } + if (arg == "-amb" || arg == "--attention-max-batch") { + CHECK_ARG + params.attn_max_batch = std::stoi(argv[i]); + return true; + } + if (arg == "-fmoe" || arg == "--fused-moe") { + params.fused_moe_up_gate = true; + return true; + } + if (arg == "-ser" || arg == "--smart-expert-reduction") { + CHECK_ARG + auto values = string_split_pairs(argv[i], ','); + if (values.size() == 1) { + params.min_experts = values.front().first; + params.thresh_experts = values.front().second; + } else { + invalid_param = true; + } + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -911,6 +1000,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mmap = false; return true; } + if (arg == "-thp" || arg == "--transparent-huge-pages") { + params.use_thp = true; + return true; + } if (arg == "--numa") { CHECK_ARG std::string value(argv[i]); @@ -1112,6 +1205,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--override-tensor" || arg == "-ot") { + CHECK_ARG + if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) { + fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]); + invalid_param = true; + } + return true; + } if (arg == "--host") { CHECK_ARG params.hostname = argv[i]; @@ -1356,6 +1457,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.warmup = false; return true; } + if (arg == "--output-format") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "jsonl") { params.sweep_bench_output_jsonl = true; } + else if (value == "md") { params.sweep_bench_output_jsonl = false; } + else { invalid_param = true; } + return true; + } + #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1452,6 +1562,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); + options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); + options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); + options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2164,8 +2278,10 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { if (bos != -1) { tmp.push_back(bos); } - tmp.push_back(eos); - + else + { + tmp.push_back(eos); + } if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); @@ -2211,12 +2327,19 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.repack_tensors = params.repack_tensors; + mparams.use_thp = params.use_thp; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); mparams.kv_overrides = params.kv_overrides.data(); } + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } return mparams; } @@ -2252,6 +2375,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } throw std::runtime_error("Invalid cache type: " + s); } @@ -2283,6 +2409,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; + cparams.fused_moe_up_gate = params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3252,6 +3383,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false"); + fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false"); fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); @@ -3280,6 +3412,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); + fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); + fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); + fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 73d7d650..b4f75236 100644 --- a/common/common.h +++ b/common/common.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + // Various helper functions and utilities #pragma once @@ -135,6 +142,7 @@ struct gpt_params { std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; + std::vector tensor_buft_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) std::vector lora_adapters; // lora adapter path with user defined scale @@ -174,6 +182,11 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) + bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models + int min_experts = -1; + float thresh_experts = 0; bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens @@ -188,6 +201,7 @@ struct gpt_params { bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data bool repack_tensors = false; // repack tensors if interleaved variant is available + bool use_thp = false; // use transparent huge pages (linux only) std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V @@ -268,6 +282,8 @@ struct gpt_params { bool spm_infill = false; // suffix/prefix/middle pattern for infill std::string lora_outfile = "ggml-lora-merged-f16.gguf"; + + bool sweep_bench_output_jsonl = false; }; void gpt_params_handle_hf_token(gpt_params & params); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3910aa1d..966cfcd3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -14,6 +14,7 @@ from enum import IntEnum from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain import math import numpy as np @@ -256,10 +257,14 @@ class Model: return False + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - for name, data_torch in self.get_tensors(): + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -1559,7 +1564,7 @@ class LlamaModel(Model): return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -1586,8 +1591,9 @@ class LlamaModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + def prepare_tensors(self): super().prepare_tensors() if self._experts is not None: @@ -1597,7 +1603,186 @@ class LlamaModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeciLMForCausalLM") +class DeciModel(Model): + model_arch = gguf.MODEL_ARCH.DECI + + @staticmethod + def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + # DeciLM-specific code + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return DeciModel._find_multiple(intermediate_size, 256) + + @staticmethod + def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str,Any]] = self.hparams["block_configs"] + assert self.block_count == len(_block_configs) + self._num_kv_heads = list() + self._num_heads = list() + _ffn_multipliers = list() + # ***linear attention layer*** + # if n_heads_in_group is None and replace_with_linear is True + # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads + # ***attention-free layer*** + # if n_heads_in_group is None and replace_with_linear is False + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 + # ***normal attention-layer*** + # if n_heads_in_group is not None, then + # _num_kv_heads[il] is num_attention_head // n_heads_in_group and + # _num_heads[il] is num_attention_head + # ***dummy layer*** for nemotron 253B + # if n_heads_in_group is None and ffn_mult is None + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0 + for il in range(len(_block_configs)): + if _block_configs[il]["attention"]["n_heads_in_group"] is None: + if _block_configs[il]["attention"]["replace_with_linear"] is True: + self._num_kv_heads.append(0) + self._num_heads.append(self.hparams["num_attention_heads"]) + else: + self._num_kv_heads.append(0) + self._num_heads.append(0) + else: + self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) + self._num_heads.append(self.hparams["num_attention_heads"]) + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + _ffn_multipliers.append(0.0) + else: + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(_ffn_multipliers) + assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) + assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int) + assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float) + self._ffn_dims: list[int] = [ + DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"]) + for multiplier in _ffn_multipliers + ] + + def set_vocab(self): + # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's + # eos_token from '|eot_id|' to '|end_of_text|' + if self.hparams.get("vocab_size", 128256) == 128256: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + else: + # DeciLM-7B + self._set_vocab_llama_hf() + + def set_gguf_parameters(self): + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(self._ffn_dims) + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + self.gguf_writer.add_head_count(self._num_heads) + self.gguf_writer.add_feed_forward_length(self._ffn_dims) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_file_type(self.ftype) + else: # DeciLM-7B + super().set_gguf_parameters() + if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B + self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"] + assert self.block_count == len(self._num_kv_heads) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + if bid is not None: + if "num_key_value_heads_per_layer" in self.hparams: + n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid] + elif "block_configs" in self.hparams: + n_kv_head = self._num_kv_heads[bid] + n_head = self._num_heads[bid] + else: + n_kv_head = self.hparams.get("num_key_value_heads") + else: + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + @Model.register("BitnetForCausalLM") +@Model.register("BitNetForCausalLM") class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET @@ -1937,6 +2122,13 @@ class Qwen2MoeModel(Model): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("Qwen3ForCausalLM") +class Qwen3Model(Qwen2Model): + model_arch = gguf.MODEL_ARCH.QWEN3 + +@Model.register("Qwen3MoeForCausalLM") +class Qwen3MoeModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3MOE @Model.register("GPT2LMHeadModel") class GPT2Model(Model): @@ -2121,6 +2313,13 @@ class Phi3MiniModel(Model): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"])) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head + # write rope scaling for long context (128k) model rope_scaling = self.find_hparam(['rope_scaling'], True) if rope_scaling is None: @@ -2150,8 +2349,8 @@ class Phi3MiniModel(Model): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) @Model.register("PlamoForCausalLM") @@ -3123,6 +3322,7 @@ class ArcticModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3144,6 +3344,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3156,6 +3365,17 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] @@ -3188,6 +3408,27 @@ class DeepseekV2Model(Model): return tensors else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] return [(self.map_tensor_name(name), data_torch)] diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index a88d0d4a..ef088034 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -331,6 +331,10 @@ if __name__ == '__main__': self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) super().set_gguf_parameters() + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Never add extra tensors (e.g. rope_freqs) for LoRA adapters + return () + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_map: dict[str, PartialLoraTensor] = {} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 67b3d277..3987fe13 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -51,5 +51,6 @@ else() add_subdirectory(save-load-state) add_subdirectory(simple) add_subdirectory(speculative) + add_subdirectory(sweep-bench) add_subdirectory(tokenize) endif() diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 497c9d14..d1693fa5 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -39,6 +41,7 @@ struct Stats { std::vector values; std::vector counts; int ncall = 0; + int n_as = 1; }; class IMatrixCollector { @@ -48,13 +51,59 @@ public: bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix(int ncall = -1) const; bool load_imatrix(const char * file_name); + void set_collect_lsim(bool yes_or_no) { m_collect_lsim = yes_or_no; } + void print_layer_importance(); private: std::unordered_map m_stats; gpt_params m_params; std::mutex m_mutex; int m_last_call = 0; + int m_last_layer = 9999; + int m_last_ffn = -1; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id + std::vector m_last_input; + std::vector m_ffn_input; + std::vector> m_layer_sim; + std::vector> m_attn_sim; + std::vector> m_ffn_sim; + bool m_collect_lsim = false; + + std::optional layer_index(const std::string& name) const { + if (name == m_params.output_tensor_name && m_last_layer < 199) { + return m_last_layer + 1; + } + if (auto pos = name.find("blk."); pos == 0) { + pos += 4; + if (auto pos1 = name.find('.', pos); pos1 != std::string::npos) { + auto index_str = name.substr(pos, pos1 - pos); + std::istringstream str(index_str); + int index; str >> index; + if (!str.fail()) return index; + } + } + return std::nullopt; + } + + static inline double cosine_similarity(int n, const float * x, const float * y) { + double sumxy = 0, sumx2 = 0, sumy2 = 0; + for (int j = 0; j < n; ++j) { + sumxy += x[j]*y[j]; sumx2 += x[j]*x[j]; sumy2 += y[j]*y[j]; + } + double cos_sim = sumx2 > 0 && sumy2 > 0 ? sumxy/sqrt(sumx2*sumy2) : 0; + return cos_sim; + } + + static inline void collect_cos_similarity(int nrow, int n, const float * x, const float * y, std::pair& p) { + for (int row = 0; row < nrow; ++row) { + p.first += cosine_similarity(n, x, y); + p.second += 1; + x += n; + y += n; + } + } + + static void print_layer_importance(const char * msg, const std::vector>& sim); }; // remove any prefix and suffixes from the name @@ -76,6 +125,45 @@ static std::string filter_tensor_name(const char * name) { return wname; } +void IMatrixCollector::print_layer_importance(const char * msg, const std::vector>& sim) { + if (sim.empty()) return; + std::vector> layers; + layers.reserve(sim.size()); + for (int i = 0; i < int(sim.size()); ++i) { + if (sim[i].second > 0) layers.emplace_back(float(std::abs(sim[i].first/sim[i].second)), i); + } + if (layers.empty()) return; + std::sort(layers.begin(), layers.end()); + printf("%s\n", msg); + //printf("======================== sorted layer importances\n"); + int j = 0; + for (auto& p : layers) { + int i = p.second; + printf("%3d: Layer %3d, = %g\n", j++, i, sim[i].first/sim[i].second); + } +} + +void IMatrixCollector::print_layer_importance() { + print_layer_importance("\n======================== sorted layer importances", m_layer_sim); + print_layer_importance("\n======================== sorted attention importances", m_attn_sim); + print_layer_importance("\n======================== sorted ffn importances", m_ffn_sim); + //printf("%s: have %d layers\n", __func__, int(m_layer_sim.size())); + //if (m_layer_sim.empty()) return; + //std::vector> layers; + //layers.reserve(m_layer_sim.size()); + //for (int i = 0; i < int(m_layer_sim.size()); ++i) { + // if (m_layer_sim[i].second > 0) layers.emplace_back(float(std::abs(m_layer_sim[i].first/m_layer_sim[i].second)), i); + //} + //if (layers.empty()) return; + //std::sort(layers.begin(), layers.end()); + //printf("======================== sorted layer importances\n"); + //int j = 0; + //for (auto& p : layers) { + // int i = p.second; + // printf("%3d: Layer %3d, = %g\n", j++, i, m_layer_sim[i].first/m_layer_sim[i].second); + //} +} + bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { GGML_UNUSED(user_data); @@ -91,7 +179,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; //printf("wname = %s\n", wname.c_str()); - if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false; + if (!(wname.substr(0, 4) == "blk." || ((m_params.process_output || m_collect_lsim) && wname == m_params.output_tensor_name))) return false; return true; } @@ -107,6 +195,33 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * data = is_host ? (const float *) src1->data : m_src1_data.data(); + if (m_collect_lsim) { + if (wname.find(".ffn_") != std::string::npos) { + if (auto index = layer_index(wname); index.has_value() && *index == m_last_layer && *index != m_last_ffn) { + int n = src1->ne[0]; + int nrow = t->op == GGML_OP_MUL_MAT_ID ? src1->ne[2] : src1->ne[1]; + if (t->op == GGML_OP_MUL_MAT_ID) { + GGML_ASSERT(src1->ne[1] == 1); + } + if (m_ffn_input.empty()) { + m_ffn_input.resize(nrow*n); + } else { + if ((int)m_ffn_input.size() != nrow*n) { + printf("Oops, inconsistent ffn size\n"); exit(1); + } + } + std::memcpy(m_ffn_input.data(), data, nrow*n*sizeof(float)); + if (m_ffn_input.size() != m_last_input.size()) { + printf("Oops, inconsistent ffn vs last_input size\n"); exit(1); + } + if (m_attn_sim.size() < *index + 1) m_attn_sim.resize(*index + 1); + auto& p = m_attn_sim[*index]; + collect_cos_similarity(nrow, n, m_ffn_input.data(), m_last_input.data(), p); + m_last_ffn = *index; + } + } + } + // this has been adapted to the new format of storing merged experts in a single 3d tensor // ref: https://github.com/ggerganov/llama.cpp/pull/6387 if (t->op == GGML_OP_MUL_MAT_ID) { @@ -132,11 +247,15 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); e.counts.resize(src1->ne[0]*n_as, 0); + e.n_as = n_as; } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); exit(1); //GGML_ABORT("fatal error"); } + else if (e.n_as != n_as) { + fprintf(stderr, "Oops: inconsistent n_as for %s (%d vs %d)\n", wname.c_str(), e.n_as, n_as); + } if (m_params.verbosity > 1) { printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); } @@ -177,6 +296,39 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } } else { + if (m_collect_lsim) { + // We only need to do it here and not in the MoE branch above because the first tensor in a layer + // never is a MoE tensor + if (auto index = layer_index(wname); index.has_value()) { + if (*index != m_last_layer) { + if (*index > 0) { + if (m_last_input.size() != src1->ne[0]*src1->ne[1]) { + printf("Oops: different size (%d vs %d). Tensor name was %s, m_last_layer = %d\n", + (int)(src1->ne[0]*src1->ne[1]), (int)m_last_input.size(), src0->name, m_last_layer); + exit(1); + } + if (*index > m_layer_sim.size()) m_layer_sim.resize(*index); + auto& p = m_layer_sim[*index - 1]; + collect_cos_similarity(src1->ne[1], src1->ne[0], m_last_input.data(), (const float *)data, p); + if (*index == m_last_ffn + 1) { + if (*index > m_ffn_sim.size()) m_ffn_sim.resize(*index); + auto& p1 = m_ffn_sim[*index-1]; + collect_cos_similarity(src1->ne[1], src1->ne[0], m_ffn_input.data(), (const float *)data, p1); + } + } + m_last_layer = *index; + if (m_last_input.empty()) { + m_last_input.resize(src1->ne[0]*src1->ne[1]); + } else { + if (m_last_input.size() != src1->ne[0]*src1->ne[1]) { + printf("Oops\n"); exit(1); + } + } + //printf("Copying src1 to m_last_input\n"); + std::memcpy(m_last_input.data(), data, src1->ne[0]*src1->ne[1]*sizeof(float)); + } + } + } auto & e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); @@ -190,7 +342,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (m_params.verbosity > 1) { printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); } - for (int row = 0; row < (int)src1->ne[1]; ++row) { + for (int row = 0; row < (int)(src1->ne[1]*src1->ne[2]); ++row) { const float * x = data + row * src1->ne[0]; for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; @@ -258,8 +410,38 @@ void IMatrixCollector::save_imatrix(int ncall) const { } if (n_zeros > 0) { - fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); - continue; + fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%)", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); + bool store_it = false; + if (kv.second.n_as > 1) { + int n_per_expert = n_all / kv.second.n_as; + std::vector bad_experts; + bad_experts.reserve(kv.second.n_as); + for (int i = 0; i < kv.second.n_as; ++i) { + auto counts = kv.second.counts.data() + i*n_per_expert; + int nz_i = 0; + for (int j = 0; j < n_per_expert; ++j) { + if (counts[j] == 0) ++nz_i; + } + if (nz_i > 0) bad_experts.push_back(i); + } + fprintf(stderr, " %d out of %d experts are missing data", int(bad_experts.size()), kv.second.n_as); + if (bad_experts.size() < round(kv.second.n_as * 0.05)) { + fprintf(stderr, " Storing **but be aware**\n"); + store_it = true; + for (auto i : bad_experts) { + auto counts = (int *)kv.second.counts.data() + i*n_per_expert; + auto values = (float *)kv.second.values.data() + i*n_per_expert; + for (int j = 0; j < n_per_expert; ++j) { + counts[j] = 1; + values[j] = 1; + } + } + } + } + if (!store_it) { + fprintf(stderr, " - skipping\n"); + continue; + } } n_entries++; @@ -587,7 +769,25 @@ int main(int argc, char ** argv) { params.logits_all = true; params.verbosity = 1; - if (!gpt_params_parse(argc, argv, params)) { + bool lsim = false; + // + // Do not pollute common with totally imatrix specific arguments as it was done in mainline. + // Instead, parse imatrix specific args here, push unknown args into a new array of args, + // and pass that to gpt_params_parse(). + // + std::vector args; + args.reserve(argc); + args.push_back(argv[0]); + for (int i = 1; i < argc; ++i) { + std::string arg{argv[i]}; + if (arg == "-lsim" || arg == "--layer-similarity") { + lsim = true; + } else { + args.push_back(argv[i]); + } + } + + if (!gpt_params_parse(args.size(), args.data(), params)) { print_usage(argc, argv, params); return 1; } @@ -595,6 +795,7 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, params.n_ctx); g_collector.set_params(params); + g_collector.set_collect_lsim(lsim); for (const auto & in_file : params.in_files) { printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str()); @@ -645,6 +846,7 @@ int main(int argc, char ** argv) { } g_collector.save_imatrix(); + g_collector.print_layer_importance(); llama_print_timings(ctx); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 42320da8..74f51494 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include #include #include @@ -41,6 +48,12 @@ static uint64_t get_time_ns() { return std::chrono::nanoseconds(clock::now().time_since_epoch()).count(); } +template +std::ostream& operator<<(std::ostream& str, const std::pair& item) { + str << '{' << item.first << ", " << item.second << '}'; + return str; +} + template static std::string join(const std::vector & values, const std::string & delim) { std::ostringstream str; @@ -215,6 +228,9 @@ static std::string pair_str(const std::pair & p) { return buf; } +// Ser = Smart Expert Reduction +using Ser = std::pair; + struct cmd_params { std::vector model; std::vector n_prompt; @@ -225,21 +241,27 @@ struct cmd_params { std::vector n_ubatch; std::vector type_k; std::vector type_v; - std::vector n_threads; + std::vector> n_threads; std::vector n_gpu_layers; std::vector rpc_servers; std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; + std::vector mla_attn; + std::vector attn_max_batch; + std::vector ser; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; + std::vector buft_overrides; ggml_numa_strategy numa; int reps; bool verbose; bool warmup; bool repack = false; + bool fmoe = false; + bool use_thp = false; output_formats output_format; output_formats output_format_stderr; }; @@ -254,21 +276,27 @@ static const cmd_params cmd_params_defaults = { /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16}, - /* n_threads */ {cpu_get_num_math()}, + /* n_threads */ {{cpu_get_num_math(), cpu_get_num_math()}}, /* n_gpu_layers */ {99}, /* rpc_servers */ {""}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* mla_attn */ {0}, + /* attn_max_batch */ {0}, + /* ser */ {{-1,0.0f}}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, + /* buft_overrides */ {}, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* verbose */ false, /* warmup */ true, /* repack */ false, + /* use_thp */ false, + /* fmoe */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -288,12 +316,16 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); + printf(" -tgb, --threads-gen-batch (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); + printf(" -amb, --attn-max-batch (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); + printf(" -ser, --smart-expert-reduction (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -304,6 +336,9 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0"); + printf(" -ot, --override-tensor pattern (default: none)\n"); + printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -336,10 +371,68 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } return GGML_TYPE_COUNT; } +namespace { +bool parse_buft_overrides(const std::string& value, std::vector& overrides) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) { + //auto * dev = ggml_backend_reg_get_name(i); + auto * buft = ggml_backend_reg_get_default_buffer_type(i); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid buft override argument %s\n", value.c_str()); + return false; + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + if (buft_list.find(buffer_type) == buft_list.end()) { + fprintf(stderr, "Available buffer types:\n"); + for (const auto & it : buft_list) { + fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second)); + } + return false; + } + overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + return true; +} +template +std::vector> string_split_pairs(const std::string & str, char delim) { + std::vector> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + if (token_stream.fail()) return {}; + } else { + T2 value; + token_stream >> value; + if (token_stream.fail()) return {}; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} +} static cmd_params parse_cmd_params(int argc, char ** argv) { cmd_params params; @@ -459,7 +552,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } auto p = string_split(argv[i], split_delim); - params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + params.n_threads.reserve(params.n_threads.size() + p.size()); + for (auto t : p) params.n_threads.push_back({t, t}); + //params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + } else if (arg == "-tgb" || arg == "--threads-gen-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto ps = string_split(argv[i], ';'); + for (auto& s : ps) { + auto p = string_split(s.c_str(), ','); + if (p.size() != 2) { + invalid_param = true; + break; + } + params.n_threads.push_back({p[0], p[1]}); + } } else if (arg == "-ngl" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; @@ -526,6 +635,27 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-mla" || arg == "--mla-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); + } else if (arg == "-amb" || arg == "--attn-max-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); + } else if (arg == "-ser" || arg == "--smart-expert-reduction") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split_pairs(argv[i], split_delim); + params.ser.insert(params.ser.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -594,6 +724,28 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.repack = std::stoi(argv[i]); + } else if (arg == "-thp" || arg == "--transparent-huge-pages") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.use_thp = std::stoi(argv[i]); + } else if (arg == "-fmoe" || arg == "--fused-moe") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fmoe = std::stoi(argv[i]); + } else if (arg == "-ot" || arg == "--override-tensor") { + if (++i >= argc) { + invalid_param = true; + break; + } + if (!parse_buft_overrides(std::string{argv[i]}, params.buft_overrides)) { + fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]); + invalid_param = true; + break; + } } else { invalid_param = true; break; @@ -621,10 +773,14 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } + if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } + if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } + if (!params.buft_overrides.empty()) params.buft_overrides.emplace_back(llama_model_tensor_buft_override{nullptr, nullptr}); return params; } @@ -649,17 +805,23 @@ struct cmd_params_instance { int n_ubatch; ggml_type type_k; ggml_type type_v; - int n_threads; + std::pair n_threads; int n_gpu_layers; std::string rpc_servers; llama_split_mode split_mode; int main_gpu; bool no_kv_offload; bool flash_attn; + int mla_attn; + int attn_max_batch; + Ser ser; std::vector tensor_split; bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; + bool use_thp = false; + const llama_model_tensor_buft_override* buft_overrides; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -673,6 +835,8 @@ struct cmd_params_instance { mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; mparams.repack_tensors = repack; + mparams.use_thp = use_thp; + mparams.tensor_buft_overrides = buft_overrides; return mparams; } @@ -685,6 +849,7 @@ struct cmd_params_instance { main_gpu == other.main_gpu && use_mmap == other.use_mmap && repack == other.repack && + use_thp == other.use_thp && tensor_split == other.tensor_split; } @@ -698,6 +863,11 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.mla_attn = mla_attn; + cparams.attn_max_batch = attn_max_batch; + cparams.fused_moe_up_gate = fmoe; + cparams.min_experts = ser.first; + cparams.thresh_experts = ser.second; cparams.embeddings = embeddings; return cparams; @@ -722,6 +892,9 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & mla : params.mla_attn) + for (const auto & amb : params.attn_max_batch) + for (const auto & ser : params.ser) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -743,10 +916,16 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -771,10 +950,16 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -799,10 +984,16 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -827,10 +1018,16 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -857,7 +1054,7 @@ struct test { uint64_t model_n_params; int n_batch; int n_ubatch; - int n_threads; + std::pair n_threads; bool has_rpc; ggml_type type_k; ggml_type type_v; @@ -866,10 +1063,15 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + int mla_attn; + int attn_max_batch; + Ser ser; std::vector tensor_split; bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; + bool use_thp = false; int n_prompt; int n_gen; std::string test_time; @@ -895,10 +1097,15 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + mla_attn = inst.mla_attn; + attn_max_batch = inst.attn_max_batch; + ser = inst.ser; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; repack = inst.repack; + fmoe = inst.fmoe; + use_thp = inst.use_thp; n_prompt = inst.n_prompt; n_gen = inst.n_gen; test_kind = inst.test_kind; @@ -988,8 +1195,8 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", - "tensor_split", "use_mmap", "embeddings", "repack", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "use_thp", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1004,13 +1211,14 @@ struct test { field == "n_threads" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || + field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" || field == "avg_ns" || field == "stddev_ns") { return INT; } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || + field == "fused_moe") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1035,6 +1243,12 @@ struct test { tensor_split_str += "/"; } } + auto ser_to_string = [] (const Ser& ser) { + std::ostringstream str; + str << ser.first << ',' << ser.second; + return str.str(); + }; + bool is_gen = n_gen > 0; std::vector values = { build_commit, std::to_string(build_number), std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan), @@ -1042,10 +1256,12 @@ struct test { cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), std::to_string(n_batch), std::to_string(n_ubatch), - std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), + std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), + std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + std::to_string(repack), std::to_string(fmoe), std::to_string(use_thp), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1208,12 +1424,27 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "mla_attn") { + return 3; + } + if (field == "attn_max_batch") { + return 5; + } + if (field == "ser") { + return 10; + } if (field == "use_mmap") { return 4; } if (field == "repack") { return 3; } + if (field == "use_thp") { + return 3; + } + if (field == "fused_moe") { + return 4; + } if (field == "test") { return 13; } @@ -1242,12 +1473,27 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "mla_attn") { + return "mla"; + } + if (field == "attn_max_batch") { + return "amb"; + } + if (field == "attn_max_batch") { + return "ser"; + } if (field == "use_mmap") { return "mmap"; } if (field == "repack") { return "rtr"; } + if (field == "use_thp") { + return "thp"; + } + if (field == "fused_moe") { + return "fmoe"; + } if (field == "embeddings") { return "embd"; } @@ -1294,6 +1540,15 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { fields.emplace_back("flash_attn"); } + if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { + fields.emplace_back("mla_attn"); + } + if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { + fields.emplace_back("attn_max_batch"); + } + if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) { + fields.emplace_back("ser"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } @@ -1306,6 +1561,12 @@ struct markdown_printer : public printer { if (params.repack != cmd_params_defaults.repack) { fields.emplace_back("repack"); } + if (params.use_thp != cmd_params_defaults.use_thp) { + fields.emplace_back("use_thp"); + } + if (params.fmoe != cmd_params_defaults.fmoe) { + fields.emplace_back("fused_moe"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); @@ -1557,10 +1818,10 @@ int main(int argc, char ** argv) { if (params.warmup) { if (t.n_prompt > 0) { //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, 1, 0, t.n_batch, t.n_threads.second); } if (t.n_gen > 0) { - test_gen(ctx, 1, 0, t.n_threads); + test_gen(ctx, 1, 0, t.n_threads.first); } } @@ -1570,11 +1831,11 @@ int main(int argc, char ** argv) { uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads.second); } if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns(); if (t.n_gen > 0) { - test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); + test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads.first); } uint64_t t_ns = get_time_ns() - t_start; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 372684f0..12702693 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.h" #include "llama.h" @@ -126,7 +133,7 @@ static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob max_logit = std::max(max_logit, logits[i]); min_logit = std::min(min_logit, logits[i]); } - min_logit = std::max(min_logit, max_logit - 16); + min_logit = std::max(min_logit, max_logit - 24); double sum_exp = 0.0; for (int i = 0; i < n_vocab; ++i) { sum_exp += expf(logits[i] - max_logit); @@ -166,7 +173,7 @@ static void process_logits( break; } lock.unlock(); - const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]); + const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]); const double v = -results.log_softmax; local_nll += v; local_nll2 += v*v; @@ -200,7 +207,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits, break; } lock.unlock(); - const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]); + const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]); local_nll += v; local_nll2 += v*v; } @@ -618,7 +625,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab); } } diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 4c5d408a..a49ebd92 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1eab8573..219032cf 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.h" #include "llama.h" @@ -58,6 +65,7 @@ static const std::vector QUANT_OPTIONS = { { "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", }, { "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", }, { "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", }, + { "Q8_KV", LLAMA_FTYPE_MOSTLY_Q8_KV, " 8.00 bpw quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, @@ -85,6 +93,7 @@ static const std::vector QUANT_OPTIONS = { { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, { "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", }, { "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", }, + { "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, @@ -136,15 +145,19 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp // [[noreturn]] static void usage(const char * executable) { - printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--hide-imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n"); + printf(" --hide-imatrix: do not store imatrix details in the quantized model\n"); printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor.\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token_embd.weight tensor.\n\n"); + printf(" --custom-q regex1=type1,regex2=type2...: use this to specify custom quantization type rules.\n\n"); + printf(" --repack Repack all tensors to the corresponding _r4/8 variant if available.\n\n"); + printf(" --repack-pattern Comma separated list of regexs to use for matching tensor names to be repacked.\n\n"); printf("Additional specific tensor quantization types used in the custom quant scheme 'CQS (default is Q2_K):\n"); printf(" --attn-q-type ggml_type: use this ggml_type for the attn_q.weight tensor.\n"); printf(" --attn-k-type ggml_type: use this ggml_type for the attn_k.weight tensor.\n"); @@ -291,6 +304,28 @@ static ggml_type parse_ggml_type(const char * arg) { return result; } +using CustomQ = std::pair; + +static bool parse_custom_quants(const std::string& arg, std::vector& custom_quants) { + for (const auto & item : string_split(arg, ',')) { + auto pos = item.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid custom quantization input %s\n", arg.c_str()); + return false; + } + auto pattern = item.substr(0, pos); + auto type_as_string = item.substr(pos + 1); + auto type = parse_ggml_type(type_as_string.c_str()); + if (type == GGML_TYPE_COUNT) { + fprintf(stderr, "Invalid quantization type '%s' in custom quantization input %s\n", type_as_string.c_str(), item.c_str()); + return false; + } + printf("Adding custom rule %s -> %s\n", pattern.c_str(), ggml_type_name(type)); + custom_quants.emplace_back(std::move(pattern), type); + } + return true; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -302,12 +337,26 @@ int main(int argc, char ** argv) { std::string imatrix_file; std::vector included_weights, excluded_weights; std::vector kv_overrides; + std::vector custom_quants; + + std::vector repack_patterns; + + bool hide_imatrix = false; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { params.quantize_output_tensor = false; } else if (strcmp(argv[arg_idx], "--ignore-imatrix-rules") == 0) { params.ignore_imatrix_rules = true; + } else if (strcmp(argv[arg_idx], "--repack") == 0) { + params.only_repack = true; + } else if (strcmp(argv[arg_idx], "--repack-pattern") == 0) { + if (arg_idx < argc-1) { + auto p = string_split(argv[++arg_idx], ','); + repack_patterns.insert(repack_patterns.end(), p.begin(), p.end()); + } else { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) { if (arg_idx < argc-1) { params.output_tensor_type = parse_ggml_type(argv[++arg_idx]); @@ -372,6 +421,10 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--custom-q") == 0) { + if (arg_idx == argc-1 || !parse_custom_quants(argv[++arg_idx], custom_quants)) { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { params.allow_requantize = true; } else if (strcmp(argv[arg_idx], "--pure") == 0) { @@ -382,6 +435,8 @@ int main(int argc, char ** argv) { } else { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--hide-imatrix") == 0) { + hide_imatrix = true; } else if (strcmp(argv[arg_idx], "--include-weights") == 0) { if (arg_idx < argc-1) { included_weights.emplace_back(argv[++arg_idx]); @@ -401,6 +456,10 @@ int main(int argc, char ** argv) { } } + if (!repack_patterns.empty()) { + params.repack_pattern = &repack_patterns; + } + if (argc - arg_idx < 2) { printf("%s: bad arguments\n", argv[0]); usage(argv[0]); @@ -418,7 +477,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.val_str, imatrix_file.c_str(), 127); + if (hide_imatrix) { + strncpy(kvo.val_str, "top_secret", 127); + } else { + strncpy(kvo.val_str, imatrix_file.c_str(), 127); + } kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } @@ -426,7 +489,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + if (hide_imatrix) { + strncpy(kvo.val_str, "top_secret", 127); + } else { + strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + } kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } @@ -435,7 +502,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = imatrix_data.size(); + if (hide_imatrix) { + kvo.val_i64 = 0; + } else { + kvo.val_i64 = imatrix_data.size(); + } kv_overrides.emplace_back(std::move(kvo)); } @@ -443,7 +514,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = m_last_call; + if (hide_imatrix) { + kvo.val_i64 = 0; + } else { + kvo.val_i64 = m_last_call; + } kv_overrides.emplace_back(std::move(kvo)); } } @@ -452,6 +527,9 @@ int main(int argc, char ** argv) { kv_overrides.back().key[0] = 0; params.kv_overrides = &kv_overrides; } + if (!custom_quants.empty()) { + params.custom_quants = &custom_quants; + } llama_backend_init(); diff --git a/examples/sweep-bench/CMakeLists.txt b/examples/sweep-bench/CMakeLists.txt new file mode 100644 index 00000000..e49f0fea --- /dev/null +++ b/examples/sweep-bench/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-sweep-bench) +add_executable(${TARGET} sweep-bench.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/sweep-bench/README.md b/examples/sweep-bench/README.md new file mode 100644 index 00000000..d92740de --- /dev/null +++ b/examples/sweep-bench/README.md @@ -0,0 +1,65 @@ +# ik_llama.cpp/example/sweep-bench + +Benchmark the prompt processing and token generation performance of `ik_llama.cpp` +by doing a sweep over a whole context size and gathering performance metrics +in each ubatch-sized window. Only a single token sequence is used. + +The benchmark steps are: + +for each ubatch-sized window in context: + + 1. generate ubatch/4 tokens (not the whole window to save some time) + 2. measure generation performance + 3. remove generated tokens from KV cache + 4. prepare a ubatch-sized batch of random tokens + 4. process prepated batch + 5. measure prompt processing performance + +The purpose of the benchmark is to visualize how the performance changes with +the context size without averaging the metrics values over the whole context. + +## Usage + +./llama-sweep-bench -c 8704 -ub 512 -m models/Meta-Llama-3.2-3B-Instruct-Q8_0.gguf + +## Sample results + +- `PP` - prompt tokens per ubatch +- `TG` - generated tokens per ubatch +- `N_KV` - current KV cache size +- `T_PP` - prompt processing time (i.e. time to first token) +- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`) +- `T_TG` - time to generate all batches +- `S_TG` - text generation speed (`(B*TG)/T_TG`) + +| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | +|-------|--------|--------|----------|----------|----------|----------| +| 512 | 128 | 0 | 1.100 | 465.51 | 2.311 | 55.38 | +| 512 | 128 | 512 | 1.183 | 432.97 | 1.895 | 67.55 | +| 512 | 128 | 1024 | 1.305 | 392.38 | 2.071 | 61.81 | +| 512 | 128 | 1536 | 1.279 | 400.42 | 2.164 | 59.14 | +| 512 | 128 | 2048 | 1.571 | 325.96 | 2.280 | 56.14 | +| 512 | 128 | 2560 | 1.431 | 357.87 | 2.418 | 52.94 | +| 512 | 128 | 3072 | 1.515 | 337.93 | 2.566 | 49.88 | +| 512 | 128 | 3584 | 1.588 | 322.34 | 2.722 | 47.03 | +| 512 | 128 | 4096 | 1.675 | 305.70 | 2.864 | 44.69 | +| 512 | 128 | 4608 | 1.769 | 289.50 | 2.999 | 42.68 | +| 512 | 128 | 5120 | 1.845 | 277.48 | 3.102 | 41.26 | +| 512 | 128 | 5632 | 1.893 | 270.46 | 3.219 | 39.76 | +| 512 | 128 | 6144 | 1.953 | 262.20 | 3.348 | 38.23 | +| 512 | 128 | 6656 | 2.018 | 253.71 | 3.474 | 36.84 | +| 512 | 128 | 7168 | 2.078 | 246.34 | 3.589 | 35.66 | +| 512 | 128 | 7680 | 2.140 | 239.22 | 3.717 | 34.43 | +| 512 | 128 | 8192 | 2.196 | 233.15 | 3.854 | 33.21 | + +### JSONL output + +Pass `--output-format jsonl` to output JSONL instead of Markdown, á la + +```json lines +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 0, "t_pp": 1.093814, "speed_pp": 468.086884, "t_tg": 1.780312, "speed_tg": 71.897514 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 512, "t_pp": 1.169302, "speed_pp": 437.868073, "t_tg": 1.897474, "speed_tg": 67.458099 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1024, "t_pp": 1.183700, "speed_pp": 432.542053, "t_tg": 2.059179, "speed_tg": 62.160694 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1536, "t_pp": 1.428625, "speed_pp": 358.386566, "t_tg": 2.160639, "speed_tg": 59.241734 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 2048, "t_pp": 1.360647, "speed_pp": 376.291595, "t_tg": 2.274003, "speed_tg": 56.288403 } +``` diff --git a/examples/sweep-bench/sweep-bench-plot.py b/examples/sweep-bench/sweep-bench-plot.py new file mode 100755 index 00000000..481a604c --- /dev/null +++ b/examples/sweep-bench/sweep-bench-plot.py @@ -0,0 +1,118 @@ +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('file', nargs='+') +args = parser.parse_args() + +df = None + +#for jsonl_file in args.file: +# # Read JSONL file into DataFrame +# df_part = pd.read_json(jsonl_file, lines=True) +# df_part['label'] = jsonl_file +# if df is None: +# df = df_part +# else: +# df = pd.concat([df, df_part]) +# + + + +for md_file in args.file: + # Read markdown table file into DataFrame + df_part = pd.read_csv(md_file, sep=r'\s*\|\s*', engine='python', + header=0, skiprows=[1]) + + # Clean up columns (remove empty columns from markdown formatting) + df_part = df_part.iloc[:, 1:-1] + df_part.columns = [col.strip() for col in df_part.columns] + + # Rename columns to match expected names + df_part = df_part.rename(columns={ + 'N_KV': 'n_kv', + 'S_PP t/s': 'speed_pp', + 'S_TG t/s': 'speed_tg' + }) + + # Convert to numeric types + df_part['n_kv'] = pd.to_numeric(df_part['n_kv']) + df_part['speed_pp'] = pd.to_numeric(df_part['speed_pp']) + df_part['speed_tg'] = pd.to_numeric(df_part['speed_tg']) + + # Add label and append to main DataFrame + df_part['label'] = md_file + df = pd.concat([df, df_part]) if df is not None else df_part + +# Group by label and n_kv, calculate mean and std for both speed metrics +df_grouped = df.groupby(['label', 'n_kv']).agg({ + 'speed_pp': ['mean', 'std'], + 'speed_tg': ['mean', 'std'] +}).reset_index() + +# Flatten multi-index columns +df_grouped.columns = ['label', 'n_kv', 'speed_pp_mean', 'speed_pp_std', + 'speed_tg_mean', 'speed_tg_std'] + +# Replace NaN with 0 (std for a single sample is NaN) +df_grouped['speed_pp_std'] = df_grouped['speed_pp_std'].fillna(0) +df_grouped['speed_tg_std'] = df_grouped['speed_tg_std'].fillna(0) + +# Prepare ticks values for X axis (prune for readability) +x_ticks = df['n_kv'].unique() +while len(x_ticks) > 16: + x_ticks = x_ticks[::2] + +# Get unique labels and color map +labels = df_grouped['label'].unique() +colors = plt.cm.rainbow(np.linspace(0, 1, len(labels))) + +# Create prompt processing plot +plt.figure(figsize=(10, 6)) +ax1 = plt.gca() +plt.grid() +ax1.set_xticks(x_ticks) + +# Plot each label's data +for label, color in zip(labels, colors): + label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') + pp = ax1.errorbar(label_data['n_kv'], label_data['speed_pp_mean'], + yerr=label_data['speed_pp_std'], color=color, + marker='o', linestyle='-', label=label) + +# Add labels and title +ax1.set_xlabel('Context Length (tokens)') +ax1.set_ylabel('Prompt Processing Rate (t/s)') +plt.title('Prompt Processing Performance Comparison') +ax1.legend(loc='upper right') + +# Adjust layout and save +plt.tight_layout() +plt.savefig('performance_comparison_pp.png', bbox_inches='tight') +plt.close() + +# Create token generation plot +plt.figure(figsize=(10, 6)) +ax1 = plt.gca() +plt.grid() +ax1.set_xticks(x_ticks) + +# Plot each model's data +for label, color in zip(labels, colors): + label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') + tg = ax1.errorbar(label_data['n_kv'], label_data['speed_tg_mean'], + yerr=label_data['speed_tg_std'], color=color, + marker='s', linestyle='-', label=label) + +# Add labels and title +ax1.set_xlabel('Context Length (n_kv)') +ax1.set_ylabel('Token Generation Rate (t/s)') +plt.title('Token Generation Performance Comparison') +ax1.legend(loc='upper right') + +# Adjust layout and save +plt.tight_layout() +plt.savefig('performance_comparison_tg.png', bbox_inches='tight') +plt.close() diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp new file mode 100644 index 00000000..27510687 --- /dev/null +++ b/examples/sweep-bench/sweep-bench.cpp @@ -0,0 +1,189 @@ +#include "ggml.h" +#include "llama.h" +#include "common.h" +#include "llama-vocab.h" + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + LOG_TEE("\nexample usage:\n"); + LOG_TEE("\n %s -m model.gguf -c 8192 -b 2048 -ub 512\n", argv[0]); + LOG_TEE("\n"); +} + +int main(int argc, char ** argv) { + + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + print_usage(argc, argv); + return 1; + } + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_params_from_gpt_params(params); + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + const unsigned int n_kv_max = llama_n_ctx(ctx); + + + const llama_vocab * vocab = llama_get_vocab(ctx); + llama_token bos = llama_token_bos_impl(*vocab); + //llama_token eos = llama_token_eos_impl(*vocab); + + const unsigned int n_vocab = llama_n_vocab(model); + + // decode in batches of ctx_params.n_batch tokens + auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + return false; + } + + llama_synchronize(ctx); + } + + return true; + }; + + const unsigned int pp = params.n_ubatch; + const unsigned int tg = params.n_ubatch / 4; + + if (!params.sweep_bench_output_jsonl) { + LOG_TEE("\n"); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("\n"); + LOG_TEE("|%6s | %6s | %6s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s"); + LOG_TEE("|%6s-|-%6s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "------", "--------", "--------", "--------", "--------"); + } + + llama_batch batch = llama_batch_init(n_kv_max, 0, 1); + + // warm up + { + llama_batch_add(batch, bos, 0, { 0 }, false); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + } + + llama_batch_clear(batch); + llama_kv_cache_clear(ctx); + + for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += params.n_ubatch) { + // clean up KV cache before generation + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + // first measure token generation performance at this context size + const auto t_tg_start = ggml_time_us(); + + for (unsigned int i = 0; i < tg; ++i) { + llama_batch_clear(batch); + llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + } + + const auto t_tg_end = ggml_time_us(); + + // clean up KV cache after generation + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + // prepare batch of pp size for prompt processing performance measurement + llama_batch_clear(batch); + + for (unsigned int i = 0; i < pp; ++i) { + llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false); + } + batch.logits[batch.n_tokens - 1] = true; + + // measure prompt processing performance + const auto t_pp_start = ggml_time_us(); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + + const auto t_pp_end = ggml_time_us(); + + // calculate and print metrics + const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f; + const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f; + + const float speed_pp = pp / t_pp; + const float speed_tg = tg / t_tg; + + if(params.sweep_bench_output_jsonl) { + LOG_TEE( + "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, " + "\"pp\": %d, \"tg\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f }\n", + n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch, + pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg + ); + } else { + LOG_TEE("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg); + } + } + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6775fdcb..70e3bbf3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -118,6 +118,7 @@ option(GGML_MUSA "ggml: use MUSA" option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) +option(GGML_CUDA_IQK_FORCE_BF16 "ggml: use bf16 cuBLAS when no MMQ kernel is available" OFF) set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 435bbae8..64a9f5c8 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The ggml authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once // @@ -396,8 +403,9 @@ extern "C" { // GGML_TYPE_I2_S = 36, // - GGML_TYPE_Q8_0_X4 = 98, - GGML_TYPE_Q8_1_X4 = 99, + GGML_TYPE_Q8_0_X4 = 97, + GGML_TYPE_Q8_1_X4 = 98, + GGML_TYPE_Q8_2_X4 = 99, GGML_TYPE_Q6_0 = 133, GGML_TYPE_IQ1_BN = 134, GGML_TYPE_IQ2_BN = 135, @@ -416,9 +424,10 @@ extern "C" { GGML_TYPE_Q8_K32 = 148, GGML_TYPE_Q8_KR8 = 149, GGML_TYPE_Q8_K128 = 150, - GGML_TYPE_IQ2_KT = 151, - GGML_TYPE_IQ3_KT = 152, - GGML_TYPE_IQ4_KT = 153, + GGML_TYPE_Q8_KV = 151, + GGML_TYPE_IQ2_KT = 152, + GGML_TYPE_IQ3_KT = 153, + GGML_TYPE_IQ4_KT = 154, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, @@ -445,6 +454,7 @@ extern "C" { GGML_TYPE_IQ4_K_R4 = 339, GGML_TYPE_IQ5_K_R4 = 340, GGML_TYPE_IQ4_KS_R4 = 344, + GGML_TYPE_Q8_KV_R8 = 398, GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, }; @@ -504,9 +514,10 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ2_KT = 140, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ3_KT = 141, // except 1d tensors - GGML_FTYPE_MOSTLY_IQ4_KT = 142, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_KT = 141, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_KT = 142, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_KT = 143, // except 1d tensors // GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -533,6 +544,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; @@ -569,6 +581,7 @@ extern "C" { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, GGML_OP_OUT_PROD, + GGML_OP_MOE_FUSED_UP_GATE, GGML_OP_SCALE, GGML_OP_SET, @@ -598,6 +611,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_ARGSORT_THRESH, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, GGML_OP_SOFT_CAP_MAX, @@ -1322,6 +1336,15 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); + // MoE up + gate + unary + GGML_API struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows @@ -1905,6 +1928,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + GGML_API struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float threshold); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, @@ -1916,6 +1945,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int k); + GGML_API struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh); #define GGML_KQ_MASK_PAD 32 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3d1a2970..74ac5374 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -254,11 +254,12 @@ if (GGML_BLAS) endif() set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp) +set (GGML_HEADERS_IQK iqk/iqk_config.h) if (GGML_IQK_MUL_MAT) message(STATUS "Using optimized iqk matrix multiplications") add_compile_definitions(GGML_USE_IQK_MULMAT) - set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp) - set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h) + set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp) + set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h) if (GGML_IQK_FA_ALL_QUANTS) message(STATUS "Including all IQK FA kernels") add_compile_definitions(GGML_IQK_FA_ALL_QUANTS) @@ -296,10 +297,12 @@ if (GGML_CUDA) # 60 == FP16 CUDA intrinsics # 61 == integer CUDA intrinsics # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + set(CMAKE_CUDA_ARCHITECTURES "native") + elseif (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() @@ -318,6 +321,8 @@ if (GGML_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-mma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) @@ -361,6 +366,10 @@ if (GGML_CUDA) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (GGML_CUDA_IQK_FORCE_BF16) + add_compile_definitions(GGML_CUDA_IQK_FORCE_BF16) + endif() + if (GGML_CUDA_FORCE_CUBLAS) add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) endif() @@ -416,6 +425,11 @@ if (GGML_CUDA) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... endif() endif() + if (NOT GGML_MUSA) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0) + endif() else() message(WARNING "CUDA not found") endif() @@ -1074,7 +1088,7 @@ if (NOT MSVC) endif() endif() -set(ARCH_FLAGS "") +set(ARCH_FLAGS ${GGML_ARCH_FLAGS}) if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR @@ -1135,6 +1149,10 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR if (GGML_SVE) list(APPEND ARCH_FLAGS -march=armv8.6-a+sve) endif() + if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + # else we fail on Gravitons and such + list(APPEND ARCH_FLAGS -flax-vector-conversions) + endif() endif() elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND @@ -1324,7 +1342,7 @@ add_library(ggml ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} ${GGML_SOURCES_IQK_MM} ${GGML_HEADERS_IQK_MM} - ${GGML_SOURCES_IQK} + ${GGML_SOURCES_IQK} ${GGML_HEADERS_IQK} ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN} ggml-aarch64.c ggml-aarch64.h ) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index e485326a..3f2d2023 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -174,6 +174,8 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz // this should never happen fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); + fprintf(stderr, "%s: tensor was %s with %g elements and %zu bytes\n", __func__, tensor->name, + 1.*ggml_nelements(tensor), ggml_nbytes(tensor)); GGML_ABORT("not enough space in the buffer"); } } diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index e1651cc6..fd538f50 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -9,6 +9,7 @@ #include #include +#define IK_PRINT_TIMING 0 #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -229,7 +230,17 @@ GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * return; } + +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif buf->iface.set_tensor(buf, tensor, data, offset, size); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + //printf("%s(%s) %zu %d us\n", __func__, tensor->name, size, (int)(tim2-tim1)); + printf("%s(%s): %d us\n", __func__, tensor->name, (int)(tim2-tim1)); +#endif + } GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -243,7 +254,15 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * return; } +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif buf->iface.get_tensor(buf, tensor, data, offset, size); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + //printf("%s(%s) %zu %d us\n", __func__, tensor->name, size, (int)(tim2-tim1)); + printf("%s(%s): %d us\n", __func__, tensor->name, (int)(tim2-tim1)); +#endif } void ggml_backend_synchronize(ggml_backend_t backend) { @@ -824,7 +843,8 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const op->type != GGML_TYPE_IQ1_S && op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: - return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + return true; + //return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; default: return true; } @@ -1751,7 +1771,11 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { struct ggml_backend_sched_split * splits = sched->splits; + for (int i = 0; i < sched->n_splits; i++) { +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif struct ggml_backend_sched_split * split = &splits[i]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; @@ -1792,6 +1816,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } if (!sched->callback_eval) { +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.1.): %d us\n", __func__, (int)(tim2-tim1)); +#endif enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) { return ec; @@ -1814,6 +1842,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.2.): %d us\n", __func__, (int)(tim2-tim1)); +#endif + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); if (ec != GGML_STATUS_SUCCESS) { return ec; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 39f3b270..7b2b2ad8 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -1,6 +1,6 @@ // -// Copyright (C) 2023-2024 The ggml authors // Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2023-2024 The ggml authors // MIT license // SPDX-License-Identifier: MIT // @@ -266,6 +266,20 @@ typedef struct { } block_q8_0x8; static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); +#define QK8_2 32 +typedef struct { + uint16_t d; + uint16_t s; + int8_t qs[QK8_2]; // quants +} block_q8_2; +static_assert(sizeof(block_q8_2) == sizeof(ggml_half) + sizeof(int16_t) + QK8_2, "wrong q8_2 block size/padding"); + +typedef struct { + uint16_t d[8]; + int8_t qs[4*QK8_2]; +} block_q8_2_x4; +static_assert(sizeof(block_q8_2_x4) == 4*sizeof(block_q8_2), "wrong q8_2_x4 block size/padding"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 61ccba23..21384217 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "ggml-cuda.h" #include "ggml.h" #include "ggml-backend-impl.h" @@ -50,6 +56,8 @@ #include #include +#define IK_PRINT_TIMING 0 + static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { @@ -1262,6 +1270,47 @@ static void ggml_cuda_op_mul_mat_cublas( return; } +#ifdef GGML_CUDA_IQK_FORCE_BF16 + if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16_cuda) { + size_t ne = row_diff*ne00; + ggml_cuda_pool_alloc src0_as_bf16(ctx.pool(id), ne); + to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream); + + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = src0_as_bf16.get(); + + ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream); + return; + } + } +#endif + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); @@ -1516,6 +1565,8 @@ static void ggml_cuda_op_mul_mat( } } + bool quantization_done = false; + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) { continue; @@ -1559,9 +1610,15 @@ static void ggml_cuda_op_mul_mat( } dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); - if (src1_on_device && src1_is_contiguous) { - quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + if (src1_on_device && (src1_is_contiguous || (src1->ne[1] == 1 && src1->ne[3] == 1 && src1->nb[0] == sizeof(float)))) { + if (src1_is_contiguous) { + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + } else { + //printf("Calling quantize_tensor_q8_1_cuda for %s\n", src0->name); + quantize_tensor_q8_1_cuda(src1, dev[id].src1_ddq, src0->type, stream); + } CUDA_CHECK(cudaGetLastError()); + quantization_done = true; } } @@ -1581,6 +1638,20 @@ static void ggml_cuda_op_mul_mat( } const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + if (!(split && used_devices > 1) && quantization_done && ne11 == 1 && ne12 > 1 && ne13 == 1) { + //printf("invoking fast path for %s x %s\n", src0->name, src1->name); + int id = ctx.device; + char * src0_dd_i = dev[id].src0_dd; + float * src1_ddf_i = dev[id].src1_ddf; + char * src1_ddq_i = dev[id].src1_ddq; + float * dst_dd_i = dev[id].dst_dd; + cudaStream_t stream = ctx.stream(id, 0); + ggml_cuda_op_mul_mat_vec_q_3D(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + dev[id].row_low, dev[id].row_high, ne11, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + return; + } + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; @@ -1647,13 +1718,17 @@ static void ggml_cuda_op_mul_mat( } } } else if (src1_on_device && !src1_is_contiguous) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d( - src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + if (!quantization_done) { + //printf("Copying %s\n", src1->name); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } } else { GGML_ABORT("fatal error"); } - if (quantize_src1 && !src1_is_contiguous) { + if (quantize_src1 && !src1_is_contiguous && !quantization_done) { + //printf("Quantizing %s\n", src1->name); quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1737,6 +1812,93 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } +/* +static void ggml_cuda_op_gemv_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + quantize_cuda_t quantize_src1) { + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nrows(src1) == 1); + GGML_ASSERT(src0_ids->ne[1] == 1); + GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]); + GGML_ASSERT(dst->ne[1] == 1); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + + GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); + + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + + int device_id = ctx.device; + GGML_ASSERT(src0_ctx->device == device_id); + GGML_ASSERT(src1_ctx->device == device_id); + GGML_ASSERT(dst_ctx->device == device_id); + + const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); + GGML_ASSERT(!split); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t nrows1 = 1; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne2 = dst->ne[2]; + + const int64_t nb2 = dst->nb[2]; + + // Why? + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); + + const size_t src0_rs = ggml_row_size(src0->type, ne00); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + ggml_cuda_pool_alloc src0_dd_alloc; + ggml_cuda_pool_alloc src1_ddf_alloc; + ggml_cuda_pool_alloc src1_ddq_alloc; + ggml_cuda_pool_alloc dst_dd_alloc; + + char * src0_dd = nullptr; + float * src1_ddf = (float *)src1->data; + char * src1_ddq = nullptr; // q8_1 + float * dst_dd = (float *)dst->data; + + bool quantization_done = false; + + const bool src1_on_device = device_id == src1_ctx->device; + const bool dst_on_device = device_id == dst_ctx->device; + + ggml_cuda_set_device(device_id); + cudaStream_t stream = ctx.stream(device_id, 0); + + src0_dd = (char *) src0->data; + + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size); + quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream); + } + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst, + (const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data, + 0, ne01, 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + +} +*/ + static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); @@ -2009,35 +2171,19 @@ struct mmid_row_mapping { int32_t i2; }; -static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, - int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, - const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, - int64_t ne11, int64_t ne10, - size_t nb11, size_t nb12) { - int32_t iid1 = blockIdx.x; - int32_t id = blockIdx.y; +static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) { + int32_t i = blockIdx.x; - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const int32_t i11 = row_mapping[i].i1 % ne11; + const int32_t i12 = row_mapping[i].i2; - if (row_id_i != i02) { - return; - } + float * src_row_contiguous = (float *)(src_contiguous + i*nb11); + const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12); - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - - __shared__ int src1_row; - if (threadIdx.x == 0) { - src1_row = atomicAdd(cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - __syncthreads(); - - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); - float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - - for (int i = threadIdx.x; i < ne10; i += blockDim.x) { - src1_row_contiguous[i] = src1_row_original[i]; + for (int j = threadIdx.x; j < ne10; j += blockDim.x) { + src_row_contiguous[j] = src_row_original[j]; } } @@ -2058,11 +2204,103 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin } } +static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, + const ggml_tensor * ids, std::vector& moe_counts, std::vector& cum_moe_counts, + ggml_cuda_pool_alloc& dev_row_mapping) { + + GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty()); + + auto stream = ctx.stream(); + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + std::vector rmapping(ids->ne[1]*n_ids); + moe_counts.resize(n_as, 0); + cum_moe_counts.resize(n_as + 1); + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i]; + } + } + cum_moe_counts[0] = 0; + for (int i = 0; i < (int)n_as; ++i) { + cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i]; + } + + dev_row_mapping.alloc(cum_moe_counts[n_as]); + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + if (row_id_i >= 0 && row_id_i < n_as) { + rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1}; + } + } + } + + for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i]; + + CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + +} + static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0->type) && + ggml_backend_buffer_is_cuda(src0->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[2] = local_dst.nb[1]; + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst, + (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data, + 0, src0->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return; + } + } + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -2072,11 +2310,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; @@ -2102,11 +2335,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[3] = nb1; if (ne12 == 1) { + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; const int64_t i11 = id % ne11; const int64_t i12 = iid1; @@ -2122,6 +2359,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } } else { + + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; + prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); @@ -2129,39 +2371,20 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.data = dst_contiguous.get(); for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - - if (row_id_i != i02) { - continue; - } - - num_src1_rows++; - } - } + int64_t num_src1_rows = moe_counts[i02]; if (num_src1_rows == 0) { continue; } - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + size_t mapping_offset = cum_moe_counts[i02]; { dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); CUDA_CHECK(cudaGetLastError()); } @@ -2187,7 +2410,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_contiguous.get(), - dev_row_mapping.get(), + dev_row_mapping.get() + mapping_offset, ne0, nb1, nb2); CUDA_CHECK(cudaGetLastError()); @@ -2196,12 +2419,352 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } -static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { +static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { + const ggml_tensor * src0_1 = dst->src[0]; + const ggml_tensor * src0_2 = dst->src[1]; + const ggml_tensor * src0 = src0_1; + const ggml_tensor * src1 = dst->src[2]; + const ggml_tensor * ids = dst->src[3]; + + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0_1->type) && + ggml_is_quantized(src0_2->type) && + ggml_backend_buffer_is_cuda(src0_1->buffer) && + ggml_backend_buffer_is_cuda(src0_2->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_1->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_2->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context; + ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_1_ctx->device == device_id && + src0_2_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[1] = local_dst.nb[2] = local_dst.nb[3] = local_dst.ne[0]*sizeof(float); + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda. + // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + } + + local_dst.data = dst_up_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, + (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(), + 0, src0_1->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + local_dst.data = dst_gate_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, + (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), + 0, src0_2->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && + ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id && + ggml_backend_buffer_is_cuda(next->buffer) && + ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); + GGML_ASSERT(dst->ne[0] % QK8_1 == 0); + auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1; + auto dst_ddq_size = n_ids*dst_row_size; + ggml_cuda_pool_alloc dst_quantized(ctx.pool(), dst_ddq_size); + quantize_row_q8_1_cuda((const float *)dst_gate_contiguous.get(), (void *)dst_quantized.get(), dst->ne[0], n_ids, 1, + dst_padded_col_size, next->src[0]->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_dst.ne[2] = 1; + + auto local_next = *next; + local_next.ne[2] = local_next.ne[1]; + local_next.ne[1] = local_next.ne[3] = 1; + local_next.nb[2] = local_next.nb[1]; + + local_src1 = *next->src[1]; + local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1; + local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size; + + auto local_src0 = *next->src[0]; + local_src0.ne[2] = local_src0.ne[3] = 1; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, + (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return true; + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + CUDA_CHECK(cudaGetLastError()); + return false; + } + } + } + + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + + cudaStream_t stream = ctx.stream(); + + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + ggml_tensor src0_1_row = *src0_1; + ggml_tensor src0_2_row = *src0_2; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + ggml_tensor final_dst; + ggml_tensor final_src; + + char * src0_1_original = (char *) src0_1->data; + char * src0_2_original = (char *) src0_2->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + src0_1_row.ne[2] = 1; + src0_1_row.ne[3] = 1; + src0_1_row.nb[3] = nb02; + src0_2_row.ne[2] = 1; + src0_2_row.ne[3] = 1; + src0_2_row.nb[3] = nb02; + + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + bool fuse_down = false; + if (next && next->op == GGML_OP_MUL_MAT_ID) { + //printf("Fusing MoE down gemm\n"); + fuse_down = true; + final_dst = *next; + final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1; + final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1]; + final_src = *next->src[0]; + //printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type), + // (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + // (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]); + final_src.ne[2] = final_src.ne[3] = 1; + final_src.nb[3] = final_src.nb[2]; + } + + if (ne12 == 1) { + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + if (fuse_down) { + final_dst.src[1] = &dst_row; + } + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]); + + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = 0; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + //dst_row.data = dst_original + i1*nb1 + i2*nb2; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + } else { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); + CUDA_CHECK(cudaGetLastError()); + + } + } + } else { + ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } + + src1_row.data = src1_contiguous.get(); + + bool first = false; //true; + + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; + + prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = moe_counts[i02]; + + if (num_src1_rows == 0) continue; + size_t mapping_offset = cum_moe_counts[i02]; + + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = nb11; + src1_row.nb[2] = num_src1_rows*nb11; + src1_row.nb[3] = num_src1_rows*nb11; + + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; + } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); + + } + else { + + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); + } + } + } + + return fuse_down; +} + +static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) { // why is this here instead of mul_mat? if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); } +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif + switch (dst->op) { case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); @@ -2310,6 +2873,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); break; + case GGML_OP_MOE_FUSED_UP_GATE: + skip_next = ggml_cuda_up_gate_unary(ctx, dst, next); + break; case GGML_OP_SCALE: ggml_cuda_op_scale(ctx, dst); break; @@ -2358,6 +2924,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_ARGSORT_THRESH: + ggml_cuda_op_argsort_thresh(ctx, dst); + break; case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; @@ -2371,6 +2940,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg CUDA_CHECK(err); } +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1)); +#endif + return true; } @@ -2596,7 +3170,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); @@ -2667,6 +3241,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!use_cuda_graph || cuda_graph_update_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; @@ -2681,11 +3256,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } #endif - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + bool skip_next = false; + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next); if (!ok) { GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + if (skip_next) ++i; } } @@ -2810,10 +3387,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: { struct ggml_tensor * a = op->src[0]; - struct ggml_tensor * b = op->src[1]; - if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { + struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1]; + if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) { + return false; + } + //================================================================== + //if (ggml_is_quantized(a->type) && ggml_is_quantized(b->type)) { + // return false; + //} + //================================================================== + if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) { return false; } if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { @@ -2917,6 +3503,17 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return true; } + if (ggml_is_quantized(src0_type) && (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_F32)) { + return true; + } + if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) { + if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) { + return true; + } + } + if (ggml_are_same_shape(op->src[0], op->src[1]) && op->src[0]->type == GGML_TYPE_Q8_0 && op->src[1]->type == GGML_TYPE_Q8_0) { + return true; + } return false; } break; case GGML_OP_DUP: @@ -2964,6 +3561,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_ACC: case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: @@ -2979,6 +3577,23 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->src[0]->ne[0] == 128) { return true; } + if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 && + (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) && + (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) { + return true; + } + if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { + return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || + (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); + } + if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) { + const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc; + int gqa = op->src[0]->ne[2]/op->src[1]->ne[2]; + return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0); + } + if (op->src[1]->ne[0] > 256) { + return false; + } if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { return true; } diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 607ded85..df214082 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "argsort.cuh" template @@ -8,7 +14,8 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, + int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -51,9 +58,18 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n } } - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + __syncthreads(); + float max_val = x_row[dst_row[0]]; + if (col < ncols) { + dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + } + } + else { + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } } @@ -65,7 +81,8 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, + ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -77,9 +94,9 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else { GGML_ABORT("fatal error"); } @@ -100,5 +117,25 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream); +} + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + int min_experts = dst->op_params[0]; + float thresh; + memcpy(&thresh, dst->op_params + 1, sizeof(float)); + + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 68a00154..7bbfdf4d 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,11 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 5abbd43c..701b0f80 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "binbcast.cuh" static __device__ __forceinline__ float op_repeat(const float a, const float b) { @@ -248,23 +255,62 @@ static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) { - GGML_ASSERT(src1->type == GGML_TYPE_F32); + //GGML_ASSERT(src1->type == GGML_TYPE_F32); - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + if (src1->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } + } + else if (src1->type == GGML_TYPE_F16) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } + } + else { GGML_ABORT("fatal error"); } } void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + GGML_ASSERT(dst->type == dst->src[0]->type); + if (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16) { + ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + return; + } + auto src = dst->src[0]; + auto bs = ggml_blck_size(src->type); + auto ts = ggml_type_size(src->type); + if (src->nb[0] != ts || ts*(src->ne[0]/bs) % 2 != 0) { + fprintf(stderr, "%s: unsupported case type = %s, nb[0] = %zu, type_size = %zu\n", __func__, ggml_type_name(src->type), src->nb[0], ts); + GGML_ABORT("fatal error"); + } + auto aux_src = *src; + aux_src.type = GGML_TYPE_F16; + aux_src.ne[0] = ts*(src->ne[0]/bs)/2; + aux_src.nb[0] = 2; + auto aux_dst = *dst; + aux_dst.type = GGML_TYPE_F16; + aux_dst.ne[0] = ts*(dst->ne[0]/bs)/2; + aux_dst.nb[0] = 2; + aux_dst.src[0] = &aux_src; + ggml_cuda_op_bin_bcast>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); } void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8c9a3706..91e03f29 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -46,10 +46,14 @@ #define CC_VOLTA 700 #define CC_TURING 750 #define CC_AMPERE 800 +#define CC_ADA_LOVELACE 890 #define CC_OFFSET_AMD 1000000 +#define CC_OFFSET_MTHREADS 0x0100000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) +#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < CC_OFFSET_MTHREADS) +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= CC_OFFSET_AMD) #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -134,6 +138,49 @@ typedef float2 dfloat2; #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE + +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { + if (cur == 0) { + GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + } + return cur; +} + +template +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } @@ -146,6 +193,15 @@ static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static bool new_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; +} + +static bool cp_async_available(const int cc) { + return cc < CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= CC_AMPERE; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index dac10ec3..ee98bf18 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,7 +1,14 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "concat.cuh" // contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -27,7 +34,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float * } } -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { +// contiguous kernels +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00, + int64_t nb02, int64_t nb12, int64_t nb2) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * nb2; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * nb02; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * nb12; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -53,7 +88,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float * } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -81,9 +116,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float * static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; + if (dim == 0 && ne1 >= 65536) { + int64_t nstep = (ne1 + 32767)/32768; + for (int64_t istep = 0; istep < nstep; ++istep) { + int64_t i1 = 32768*istep; + int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1; + dim3 gridDim(num_blocks, n1, ne2); + const float * xi = x + i1*ne00; + const float * yi = y + i1*(ne0 - ne00); + float * dst_i = dst + i1*ne0; + concat_f32_dim0<<>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1); + } + return; + } dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + //concat_f32_dim0<<>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1); return; } if (dim == 1) { @@ -150,35 +199,90 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + cudaStream_t stream = ctx.stream(); const int32_t dim = ((int32_t *) dst->op_params)[0]; + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); + return; + } + + if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) && + src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) { + auto bs = ggml_blck_size(dst->type); + auto ts = ggml_type_size(dst->type); + auto ne00_eff = (src0->ne[0]/bs)*ts/sizeof(float); + auto ne0_eff = (dst->ne[0]/bs)*ts/sizeof(float); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) { + // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name); + // GGML_ABORT("fatal error"); + //} + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + //printf("%s(%s, %s): %ld %zu %zu %ld %zu %zu\n", __func__, src0->name, src1->name, src0->ne[0], src0->nb[0], src0->nb[1], dst->ne[0], dst->nb[0], dst->nb[1]); + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + ne00_eff, src0->ne[1], src0->ne[2], + ne0_eff, dst->ne[1], dst->ne[2], dim, stream); + //src0->nb[1]/sizeof(float), src0->ne[1], src0->ne[2], + //dst->nb[1]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); + //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], + //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); + } + } else { + //printf("%s(not contiguous): %s(%s) and %s(%s)\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type)); + auto ne10_eff = (src1->ne[0]/bs)*ts/sizeof(float); + dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); + concat_f32_non_cont<<>>( + (const char *)src0->data, + (const char *)src1->data, + ( char *)dst->data, + ne00_eff, src0->ne[1], src0->ne[2], src0->ne[3], + //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3], + sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3], + ne10_eff, src1->ne[1], src1->ne[2], src1->ne[3], + //src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3], + sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3], + ne0_eff, dst->ne[1], dst->ne[2], dst->ne[3], + //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3], + sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim); + } + return; + } + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) { + // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name); + // GGML_ABORT("fatal error"); + //} const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; float * dst_d = (float *)dst->data; - if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); - } - } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 2d8f023f..1a3d96ed 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -674,9 +674,16 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = scale * ((x[i].scales[ib] & 254) - 127); const int8_t * values = iq4k_values + ((x[i].scales[ib] & 1) << 4); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[q4[j] & 0xf]; - y[j+16] = d * values[q4[j] >> 4]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d * values[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[q4[j] & 0xf]; + y[j+16] = d * values[q4[j] >> 4]; + } } } @@ -705,9 +712,16 @@ static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, ds aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f); aux32[0] &= 0x0f0f0f0f; const uint8_t * aux8 = (const uint8_t *)aux32; - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[aux8[j+0]]; - y[j+16] = d * values[aux8[j+4]]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[aux8[j+0]]); + y[j+16] = __float2bfloat16(d * values[aux8[j+4]]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[aux8[j+0]]; + y[j+16] = d * values[aux8[j+4]]; + } } } @@ -727,9 +741,16 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_ const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d1 * values1[q4[j] & 0xf]; - y[j+16] = d2 * values2[q4[j] >> 4]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d1 * values1[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d2 * values2[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } } } @@ -751,12 +772,22 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const uint8_t extra = x[i].extra >> 4*(ib64%4); - for (int j = 0; j < 2; ++j) { - const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); - y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; - y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; - y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; - y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = __float2bfloat16(dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]); + y[j+16] = __float2bfloat16(dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]); + y[j+32] = __float2bfloat16(dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]); + y[j+48] = __float2bfloat16(dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; + y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; + y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; + y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + } } } @@ -784,10 +815,17 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_ uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); - y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); - y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); - y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); - y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + if constexpr (std::is_same_v) { + y[j+ 0] = __float2bfloat16(dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0))); + y[j+16] = __float2bfloat16(dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0))); + y[j+32] = __float2bfloat16(dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0))); + y[j+48] = __float2bfloat16(dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0))); + } else { + y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); + y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); + y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); + y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + } } } @@ -808,11 +846,20 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_ const float dl4 = d * (((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 8); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + } } } @@ -836,11 +883,20 @@ static __global__ void dequantize_block_iq2_ks(const void * __restrict__ vx, dst const float dl3 = d * (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16); const float dl4 = d * (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + } } } @@ -863,12 +919,22 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - const uint8_t h = qh[j] >> (4*(ib128%2)); - y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; - y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; - y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; - y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = __float2bfloat16(dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]); + y[j+32] = __float2bfloat16(dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]); + y[j+64] = __float2bfloat16(dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]); + y[j+96] = __float2bfloat16(dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } } } @@ -1180,6 +1246,22 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return convert_to_bf16_cuda; case GGML_TYPE_F16: return convert_to_bf16_cuda; + case GGML_TYPE_IQ2_KS: + return dequantize_row_iq2_ks_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ4_KSS: + return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ4_KS: + return dequantize_row_iq4_ks_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ5_K: + return dequantize_row_iq5_k_cuda; + case GGML_TYPE_IQ6_K: + return dequantize_row_iq6_k_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 0efcecde..5d107450 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 00000000..ecb65999 --- /dev/null +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,46 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. +static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 0b269a86..3b87cbad 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,12 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "cpy.cuh" +#include "convert.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -65,6 +73,71 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +template +static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; + + const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); + const int ib = i00/QK8_0; + const int iq = i00%QK8_0; + + if constexpr (std::is_same_v) { + dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]); + } else { + dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq]; + } +} + +static __global__ void k_transpose_q8_0(const char * cx, char * cdst, + const int ne10, const int ne11, const int ne12, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + + //const int64_t ne00 = ne11; + //const int64_t ne01 = ne10; + //const int64_t ne02 = ne12; + const int64_t i03 = i13; + const int64_t i02 = i12; + const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; + const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; + + const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); + const int ib0 = i00/QK8_0; + const int iq0 = i00%QK8_0; + + float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0]; + float amax = fabsf(xi); + amax = warp_reduce_max(amax); + + //printf("%d, %d, %d: i = %ld, i11 = %ld i10 = %ld, xi = %g, amax = %g\n", blockDim.x, blockIdx.x, threadIdx.x, i, i11, i10, xi, amax); + + float d = amax/127; + int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13); + dst[i10 / QK8_0].qs[i10 % QK8_0] = q; + + if (threadIdx.x == 0) { + dst[i10 / QK8_0].d = __float2half(d); + } +} + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q8_0 * dsti = (block_q8_0 *) cdsti; @@ -464,6 +537,35 @@ static void ggml_cpy_f16_f16_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + auto stream = ctx.stream(); + auto num_blocks = ggml_nelements(dst)/QK8_0; + k_transpose_q8_0<<>>( + (const char *)src->data, (char *)dst->data, + dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3], + dst->nb[1], dst->nb[2], dst->nb[3]); +} + +static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + auto stream = ctx.stream(); + auto num_blocks = ggml_nelements(dst)/QK8_0; + if (dst->type == GGML_TYPE_F16) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else if (dst->type == GGML_TYPE_F32) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else if (dst->type == GGML_TYPE_BF16) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else { + GGML_ABORT("fatal error"); + } +} + void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -520,11 +622,40 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && + (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) { + copy_q8_0_to_float(ctx, src0, src1); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) { + to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) { + to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) { + to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + transpose_q8_0(ctx, src0, src1); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); GGML_ABORT("fatal error"); } } @@ -556,12 +687,33 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && + (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) { + return (void*)copy_q8_0_to_float; + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) return (void*)to_fp16; + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) return (void*)to_fp32; + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) return (void*)to_bf16; + } } + else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + return (void *)transpose_q8_0; + } + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); + GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0a664dbd..5e8ee0f6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" @@ -52,7 +59,7 @@ typedef half (*vec_dot_KQ_f16_t)( typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -62,7 +69,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -92,7 +99,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -102,7 +109,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -142,7 +149,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) return *((const int *) &val0_8); } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -152,7 +159,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -179,7 +186,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -189,7 +196,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -226,7 +233,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -236,7 +243,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -277,7 +284,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -287,7 +294,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -320,7 +327,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -330,7 +337,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -353,7 +360,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -368,7 +375,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -384,7 +391,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -603,29 +610,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v return x[i]; } -template +template constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } -template +template constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } @@ -653,20 +660,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // Dv == V head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dv, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; + VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += Dv * gridDim.y*blockIdx.x; const int tid = threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dv); __shared__ float2 meta[parallel_blocks]; if (tid < 2*parallel_blocks) { @@ -690,20 +697,20 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; + dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator; } -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { +static void on_no_fattn_vec_case(const int Dk, const int Dv) { + if (Dk == 64 && Dv == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); fprintf(stderr, "By default only f16 KV cache is supported.\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); GGML_ABORT("fatal error"); - } else if (D == 128) { + } else if (Dk == 128 && Dv == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); @@ -715,14 +722,22 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); + } + else if (Dk == 192 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n"); + // TODO: add what is supported + } + else if (Dk == 576 && Dv == 512) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n"); + // TODO: add what is supported } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V @@ -838,12 +853,345 @@ void launch_fattn( return; } - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(Dv, 1, 1); const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); const int shmem_combine = 0; - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_mma_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_mma_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; +} + +template +void launch_fattn_mma( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_mma_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_mma_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 00000000..af38071d --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,1047 @@ +#include "common.cuh" +#include "cp-async.cuh" +#include "mma_new.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + + cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + } + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); +#ifdef CP_ASYNC_AVAILABLE + constexpr int preload = KQ_per_iter * sizeof(half); + constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + + cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } +#else + constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + + tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +#endif // CP_ASYNC_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int kb0) { +#ifdef INT8_MMA_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + const int k_VKQ_0 = kb0 * KQ_per_iter; + tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + + if (use_logit_softcap) { + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } + } + +#ifdef CP_ASYNC_AVAILABLE + // Preload K tile for next iteration: + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q1, + const int stride_Q2, + const int stride_KV, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef INT8_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_per_iter*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE + half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + + tile_B Q_B[D/(2*tile_B::J) * ntiles]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + } + } + } + } + + __syncthreads(); + + // Preload mask and K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*cols_per_warp > KQ_per_iter) { + __syncthreads(); + } +#endif // CP_ASYNC_AVAILABLE + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); + + tile_K[jc_cwd*D2_padded + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } + } + } + + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +__launch_bounds__(nwarps*WARP_SIZE, 2) +static __global__ void flash_attn_mma_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float logit_softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if defined(INT8_MMA_AVAILABLE) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) +} + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int ncols = ncols1 * ncols2; + constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; + constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; + constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int cols_per_warp = ntiles * tile_B::I; + + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); + + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + + const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); + const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); + + const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_mma_ext_f16; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_mma_ext_f16; + } + + launch_fattn_mma + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); +} + + +#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) + +// Kernels with ncols == 128 are only 4% faster due to register pressure. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu new file mode 100644 index 00000000..d1484451 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -0,0 +1,1725 @@ +// Adapted from https://github.com/ggml-org/llama.cpp/pull/13306 +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "fattn-new-mma.cuh" +#include "cp-async.cuh" +#include "mma_new.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. +// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. + +template +struct fattn_mma_f16_config; + +// +// The previous MMA version is better (faster) +// I'm keeping these around commented out for now, +// and only using the 576, 512 case. +// +//template <> +//struct fattn_mma_f16_config< 64, 64> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 32; +// static constexpr int nbatch_V2 = 32; +// static constexpr int nbatch_combine = 32; +//}; +// +//template <> +//struct fattn_mma_f16_config< 80, 80> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 40; +// static constexpr int nbatch_V2 = 40; +// static constexpr int nbatch_combine = 40; +//}; +// +//template <> +//struct fattn_mma_f16_config< 96, 96> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 48; +// static constexpr int nbatch_V2 = 48; +// static constexpr int nbatch_combine = 48; +//}; +// +//template <> +//struct fattn_mma_f16_config<112, 112> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 56; +// static constexpr int nbatch_V2 = 56; +// static constexpr int nbatch_combine = 56; +//}; +// +//template <> +//struct fattn_mma_f16_config<128, 128> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 64; +// static constexpr int nbatch_V2 = 64; +// static constexpr int nbatch_combine = 64; +//}; +// +//template <> +//struct fattn_mma_f16_config<192, 128> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 96; +// static constexpr int nbatch_V2 = 64; +// static constexpr int nbatch_combine = 64; +//}; +// +//template <> +//struct fattn_mma_f16_config<256, 256> { +// static constexpr int nbatch_fa = 32; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 128; +// static constexpr int nbatch_V2 = 128; +// static constexpr int nbatch_combine = 128; +//}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + static constexpr int nbatch_K2 = 160; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { + + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + + if constexpr (use_cp_async) { + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + + const int chunks_per_row = D2 / h2_per_chunk; + + int k0_start = 0; +#pragma unroll + for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) { + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + + if (k0_start == k0_stop) { + continue; + } + + const int stride_i = WARP_SIZE / stride_k; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k; + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + k0_start = k0_stop; + } + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= 0) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); + + if constexpr (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; + } + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_K, + const int stride_V, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_Q, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int kb0) { +#ifdef INT8_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + } + +#pragma unroll + for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { + const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } + + if (use_logit_softcap) { + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } + } + + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + } + } + +#pragma unroll + for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { + const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; + const int i0_diff = i0_stop - i0_start; + + if (nstages == 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q1, + const int stride_Q2, + const int stride_K, + const int stride_V, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef INT8_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; + + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; + + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + if (c::Q_in_reg) { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); + } + } + } + } + + __syncthreads(); + + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + } + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With multi-stage loading there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. + + constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); + + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); + + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } + } + } + + __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } + } + } + } + } + if (np > 1) { + __syncthreads(); + } + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float logit_softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if defined(INT8_MMA_AVAILABLE) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + NO_DEVICE_CODE; + return; + } + + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_K = nb11 / sizeof(half2); + const int stride_V = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_combine_results_new( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; +} + +template +void launch_fattn_new_mma( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + nb11 = K->ne[0]*sizeof(half); + nb12 = nb11*K->ne[1]; + nb13 = nb12*K->ne[2]; + + // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are + // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory. + //const size_t bs = ggml_blck_size(K->type); + //const size_t ts = ggml_type_size(K->type); + + //nb11 = nb11*bs*sizeof(half)/ts; + //nb12 = nb12*bs*sizeof(half)/ts; + //nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if constexpr (DV == 512) { + // DeepSeek. In this case the V cache is the same as the K cache, except that + // it has 512 elements per row instead of 576. + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + V_data = K_data; + } else { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = K->ne[0]*sizeof(half); + nb22 = nb21*V->ne[1]; + nb23 = nb22*V->ne[2]; + + // Original PR in llama.cpp. Same comment as above for the K cache. + //const size_t bs = ggml_blck_size(V->type); + //const size_t ts = ggml_type_size(V->type); + + //nb21 = nb21*bs*sizeof(half)/ts; + //nb22 = nb22*bs*sizeof(half)/ts; + //nb23 = nb23*bs*sizeof(half)/ts; + } + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = max_blocks_per_sm*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_combine_results_new + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + typedef fattn_mma_f16_config c; + + constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; + constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; + constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + + constexpr int ncols = ncols1 * ncols2; + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); + + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; + + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + } + + launch_fattn_new_mma + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + + //switch (Q->ne[0]) { + // case 64: + // GGML_ASSERT(V->ne[0] == 64); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + // break; + // case 80: + // GGML_ASSERT(V->ne[0] == 80); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + // break; + // case 96: + // GGML_ASSERT(V->ne[0] == 96); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + // break; + // case 112: + // GGML_ASSERT(V->ne[0] == 112); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + // break; + // case 128: + // GGML_ASSERT(V->ne[0] == 128); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + // break; + // case 192: + // GGML_ASSERT(V->ne[0] == 128); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst); + // break; + // case 256: + // GGML_ASSERT(V->ne[0] == 256); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + // break; + // case 576: { + // // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + // GGML_ASSERT(V->ne[0] == 512); + // float max_bias = 0.0f; + // memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // const bool use_gqa_opt = mask && max_bias == 0.0f; + // GGML_ASSERT(use_gqa_opt); + + // GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + // const int gqa_ratio = Q->ne[2] / K->ne[2]; + // GGML_ASSERT(gqa_ratio % 16 == 0); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + // } break; + // default: + // GGML_ABORT("fatal error"); + // break; + //} +} + diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cuh b/ggml/src/ggml-cuda/fattn-new-mma.cuh new file mode 100644 index 00000000..40f867df --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-new-mma.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d1bbf01f..420f0bb0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f16.cuh" @@ -291,13 +298,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 25908d7a..f525f1bb 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -1,10 +1,17 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f32.cuh" #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -44,8 +51,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -61,15 +69,22 @@ static __global__ void flash_attn_tile_ext_f32( const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_KV2 = nb11 / sizeof(half2); + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb12 / sizeof(half2); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + // TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ? + // let's assume it is is both. + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + + constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. + // This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension + __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts. float2 * KV_tmp2 = (float2 *) KV_tmp; float kqmax[ncols/nwarps]; @@ -79,16 +94,16 @@ static __global__ void flash_attn_tile_ext_f32( } float kqsum[ncols/nwarps] = {0.0f}; - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; + float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; + __shared__ float Q_f[ncols][Dk]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) { float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; @@ -112,8 +127,8 @@ static __global__ void flash_attn_tile_ext_f32( const int i_KQ = i_KQ_0 + threadIdx.y; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) { + const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -124,7 +139,7 @@ static __global__ void flash_attn_tile_ext_f32( float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; #pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { + for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) { float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; float Q_k[ncols/nwarps]; @@ -193,7 +208,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; } @@ -206,11 +221,11 @@ static __global__ void flash_attn_tile_ext_f32( const int k = k0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); + KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); } } @@ -218,14 +233,14 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; + float2 V_k[(Dv/2)/WARP_SIZE]; float KQ_k[ncols/nwarps]; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; + V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i]; } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -235,7 +250,7 @@ static __global__ void flash_attn_tile_ext_f32( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; @@ -259,7 +274,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum_j = warp_reduce_sum(kqsum_j); #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { + for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) { const int i0 = i00 + 2*threadIdx.x; float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; @@ -268,8 +283,8 @@ static __global__ void flash_attn_tile_ext_f32( dst_val.y /= kqsum_j; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y; } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -285,14 +300,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 7f14e78b..2cf4f4ef 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,9 +1,16 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, @@ -43,14 +50,15 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { #ifdef FP16_AVAILABLE // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); + constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); @@ -67,12 +75,13 @@ static __global__ void flash_attn_vec_ext_f16( const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ half KQ[ncols*D]; + __shared__ half KQ[ncols*Dk]; half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; @@ -94,9 +103,9 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + half2 Q_h2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +116,18 @@ static __global__ void flash_attn_vec_ext_f16( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); } continue; @@ -126,7 +135,7 @@ static __global__ void flash_attn_vec_ext_f16( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +144,11 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -154,7 +163,7 @@ static __global__ void flash_attn_vec_ext_f16( const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -166,13 +175,13 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; + KQ[j*Dk + tid] = -HALF_MAX_HALF; } half2 VKQ[ncols] = {{0.0f, 0.0f}}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, @@ -186,10 +195,10 @@ static __global__ void flash_attn_vec_ext_f16( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -209,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f16( } if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -234,9 +243,9 @@ static __global__ void flash_attn_vec_ext_f16( const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const half val = hexp(KQ[j*D + tid] - kqmax[j]); + const half val = hexp(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= __half2half2(KQ_max_scale); } @@ -244,8 +253,8 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + for (int k0 = 0; k0 < Dv; k0 += 2) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -254,7 +263,7 @@ static __global__ void flash_attn_vec_ext_f16( reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; } } @@ -285,27 +294,28 @@ static __global__ void flash_attn_vec_ext_f16( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -325,9 +335,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -336,9 +346,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -347,9 +357,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -358,9 +368,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -368,15 +378,19 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -435,3 +449,6 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 1aa88272..c91cef3d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,9 +1,16 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, @@ -42,14 +49,15 @@ static __global__ void flash_attn_vec_ext_f32( const int ne2, const int ne3) { // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); + constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); @@ -64,15 +72,16 @@ static __global__ void flash_attn_vec_ext_f32( const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ float KQ[ncols*D]; + __shared__ float KQ[ncols*Dk]; #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -FLT_MAX/2.0f; + KQ[j*Dk + tid] = -FLT_MAX/2.0f; } float kqmax[ncols]; @@ -94,9 +103,9 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: - float2 Q_f2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; - float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + float2 Q_f2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk >= Dk/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +116,18 @@ static __global__ void flash_attn_vec_ext_f32( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); } continue; @@ -126,7 +135,7 @@ static __global__ void flash_attn_vec_ext_f32( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +144,11 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -153,7 +162,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -165,8 +174,8 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new_arr[ncols]; @@ -176,10 +185,10 @@ static __global__ void flash_attn_vec_ext_f32( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -195,7 +204,7 @@ static __global__ void flash_attn_vec_ext_f32( kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -220,9 +229,9 @@ static __global__ void flash_attn_vec_ext_f32( const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const float val = expf(KQ[j*D + tid] - kqmax[j]); + const float val = expf(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= KQ_max_scale; } @@ -230,15 +239,15 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); #pragma unroll - for (int k = 0; k < D; ++k) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { + for (int k = 0; k < Dv; ++k) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) { break; } const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_ki*KQ[j*D + k]; + VKQ[j] += V_ki*KQ[j*Dk + k]; } } @@ -269,24 +278,25 @@ static __global__ void flash_attn_vec_ext_f32( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -303,9 +313,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -314,9 +324,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -325,9 +335,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -336,9 +346,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -346,15 +356,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f32_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f32_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -406,3 +420,6 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index efe78a2f..d39c6a6e 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" @@ -5,8 +12,8 @@ #include #endif // FP16_MMA_AVAILABLE -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel: +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -47,8 +54,9 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -58,11 +66,11 @@ static __global__ void flash_attn_ext_f16( const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0."); typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; @@ -74,30 +82,32 @@ static __global__ void flash_attn_ext_f16( static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; + constexpr int Dk_padded = Dk + 8; + constexpr int Dv_padded = Dv + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); + const int stride_Q = nb01 / sizeof(float); + const int stride_K = nb11 / sizeof(half); + const int stride_V = nb21 / sizeof(half); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); const half2 softcap_2 = make_half2(softcap, softcap); - frag_b Q_b[D/16][ncols/frag_n]; + frag_b Q_b[Dk/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; @@ -120,18 +130,18 @@ static __global__ void flash_attn_ext_f16( KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); } - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + __shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f); } } @@ -140,12 +150,12 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dk && i >= Dk) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -153,10 +163,10 @@ static __global__ void flash_attn_ext_f16( // Load Q into tensor core fragments/registers since it will be used frequently: #pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { + for (int i0 = 0; i0 < Dk; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded); } } @@ -173,9 +183,9 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) { frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -309,9 +319,9 @@ static __global__ void flash_attn_ext_f16( } } - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n]; #pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { + for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); @@ -322,7 +332,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -332,15 +342,15 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); + Dk_padded, nvcuda::wmma::mem_col_major); } } @@ -358,18 +368,18 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } half2 VKQ_add = make_half2(0.0f, 0.0f); #pragma unroll for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add; } } @@ -392,16 +402,16 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dv && i >= Dv) { break; } - float dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*Dv_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val; } if (parallel_blocks == 1 || threadIdx.x != 0) { @@ -446,13 +456,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -template +template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16; const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; @@ -462,29 +472,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t) \ + template void ggml_cuda_flash_attn_ext_wmma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); @@ -518,3 +532,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c15d6c81..725b443d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f16.cuh" @@ -5,13 +12,106 @@ #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" +#include "fattn-mma-f16.cuh" +#include "fattn-new-mma.cuh" #include "fattn.cuh" #include +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const float use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 2) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); +} + static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * V = dst->src[2]; + + if (Q->ne[0] != V->ne[0]) { + if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) { + fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]); + GGML_ABORT("fatal error"); + } + } const int32_t precision = KQV->op_params[3]; @@ -20,22 +120,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -46,19 +149,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; // case 256: // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); @@ -76,16 +182,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -99,22 +208,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -127,22 +239,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -152,7 +267,13 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ return; \ } \ @@ -212,12 +333,16 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -226,19 +351,30 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ return; \ } \ @@ -298,6 +434,10 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -306,14 +446,21 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; @@ -330,7 +477,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); @@ -339,23 +486,60 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (precision == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations + // So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster. + // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. + //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies; + const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0; + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (precision == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } else if(Q->ne[0] <= 128) { + } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - return; } + return; } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + // + // It turns out the new new MMA implementation is slower than the + // previous MMA implementation. + // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512, + // so no other implementation works. + // + if (new_mma_available(cc) && Q->ne[0] == 576) { + ggml_cuda_flash_attn_ext_mma_new(ctx, dst); + return; + } + + // + // We need this because I haven't adapted new MMA kernels to work for different + // K and V head sizes. + // We also need it if the new MMA is not available + // + if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; + } + + // As mentioned above, the new new MMA is slower than then the new MMA. + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + //ggml_cuda_flash_attn_ext_mma_new(ctx, dst); } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c370323..f734271a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -1,10 +1,17 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "getrows.cuh" #include "dequantize.cuh" template static __global__ void k_get_rows( const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -31,7 +38,11 @@ static __global__ void k_get_rows( // dequantize dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); + if (i01 >= 0 && i01 < ne01) { + dequantize_kernel(src0_row, ib, iqs, v); + } else { + v.x = v.y = 0; + } dst_row[iybs + iqs + 0] = v.x; dst_row[iybs + iqs + y_offset] = v.y; @@ -40,7 +51,7 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -56,11 +67,10 @@ static __global__ void k_get_rows_float( } const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = src0_row[i00]; + dst_row[i00] = i01 >= 0 && i01 < ne01 ? dst_t(src0_row[i00]) : dst_t(0); } template @@ -88,7 +98,7 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg k_get_rows<<>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, @@ -120,7 +130,7 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr k_get_rows_float<<>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 8e741613..1b8fbff5 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -15,11 +15,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ namespace { template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -// tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__global__ void iqk_mul_mat_vec_q( +__device__ void iqk_mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) { @@ -94,10 +90,29 @@ __global__ void iqk_mul_mat_vec_q( } } +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void iqk_mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size, + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { + int i2 = blockIdx.y; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) return; + const char * cx = (const char *)vx + i02*nb02; + const char * cy = (const char *)vy + i2*nb12; + char * cdst = (char *)dst + i2*nb2; + iqk_mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); +} + template void iqk_mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -132,35 +147,35 @@ void iqk_mul_mat_vec_q_cuda( } } const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, 1, 1); + const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); const int64_t row_size = ggml_row_size(type, ncols_x); switch (ncols_y) { case 1: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 2: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 3: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 4: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 5: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 6: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 7: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 8: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; default: GGML_ASSERT(false); @@ -734,45 +749,51 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( } // namespace void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_kt_q8_1_cuda( @@ -783,27 +804,31 @@ void mul_mat_vec_iq2_kt_q8_1_cuda( } void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index a4b271bd..781e1afa 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -1,45 +1,61 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_kt_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mma_new.cuh b/ggml/src/ggml-cuda/mma_new.cuh new file mode 100644 index 00000000..5bba41c6 --- /dev/null +++ b/ggml/src/ggml-cuda/mma_new.cuh @@ -0,0 +1,396 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + +#include "common.cuh" + + +#if CUDART_VERSION >= 11080 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef INT8_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "=r"(ret) : "r"(x)); +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11080 + +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} + +namespace ggml_cuda_mma { + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 16) { + return ((l / 2) % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else if constexpr (I == 16 && J == 16) { + return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + + return ret; + } + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); + GGML_UNUSED(t); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(t); + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } +} diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index f9fc2438..67897a83 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "mmq.cuh" void ggml_cuda_op_mul_mat_q( @@ -7,6 +14,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; + const int64_t nb01 = src0->nb[1]; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; @@ -15,7 +23,6 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -24,7 +31,7 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -84,6 +91,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_NL: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ4_KS: + mul_mat_q_case(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -121,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_KS: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 416b4336..148697e2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" @@ -75,6 +82,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_KS: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -153,26 +161,28 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : - type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : - tile_x_sizes{0, 0, 0}; + switch (type) { + case GGML_TYPE_Q4_1 : return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1 : return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q8_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q2_K : return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K : return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K : return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K : return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K : return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + default : return tile_x_sizes{0, 0, 0}; + } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) @@ -188,26 +198,29 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : - 0; + switch (type) { + case GGML_TYPE_Q4_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1 : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1 : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q8_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q2_K : return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K : return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + default : return 0; + } } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) @@ -251,7 +264,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE @@ -273,7 +286,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -346,7 +359,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE @@ -368,7 +381,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -441,7 +454,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); @@ -480,7 +493,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -513,7 +526,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b4(bxi->qs, kqsx); const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); @@ -550,7 +563,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -583,7 +596,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx; + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2); @@ -613,7 +626,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd; + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -646,7 +659,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); @@ -668,7 +681,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -1034,7 +1047,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); @@ -1265,7 +1278,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); @@ -1295,7 +1308,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int ksc = threadIdx.x % (WARP_SIZE/8); @@ -1331,7 +1344,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; x_df[i] = bxi->d; } @@ -1402,7 +1415,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int qs0 = get_int_b4(bxi->qs, threadIdx.x); #ifdef INT8_MMA_AVAILABLE @@ -1423,7 +1436,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/16); @@ -1452,7 +1465,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -1465,7 +1478,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; @@ -1531,7 +1544,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int ky = QR5_K*threadIdx.x; const int ql = get_int_b4(bxi->qs, threadIdx.x); @@ -1564,7 +1577,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/16); @@ -1593,7 +1606,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -1606,7 +1619,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; @@ -1673,7 +1686,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; const int ql = get_int_b2(bxi->ql, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; @@ -1706,7 +1719,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; @@ -1723,7 +1736,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / 4; #ifdef INT8_MMA_AVAILABLE x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); @@ -1898,7 +1911,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); @@ -1923,7 +1936,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); @@ -1955,7 +1968,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0; const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); const uint8_t * aux8 = (const uint8_t *) &q2; @@ -2013,7 +2026,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0; const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint16_t * q2 = (const uint16_t *) &q2_packed; @@ -2069,7 +2082,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2132,7 +2145,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0; const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * q3 = (const uint8_t *) &q3_packed; @@ -2188,7 +2201,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0; const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2251,7 +2264,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2308,7 +2321,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); @@ -2330,7 +2343,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; const float d = __half2float(bxi->d); @@ -2345,6 +2358,64 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq4_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = 0; // threadIdx.x / QI4_XS + const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx; + + auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4); + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, values); + const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0; + const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { @@ -2566,6 +2637,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -2598,7 +2677,7 @@ static __device__ void mul_mat_q_process_tile( const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x + stride01*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01); { const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); @@ -2879,6 +2958,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + } else { constexpr bool need_check = true; @@ -2887,6 +2967,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + } } @@ -3000,6 +3081,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 80364373..ff76f34d 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,61 +1,68 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "mmvq.cuh" -#include "vecdotq.cuh" #include "iqk_mmvq.cuh" +#include "vecdotq.cuh" typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : - type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : - type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : - type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : - type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : - type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : - type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : - type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : - type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : - nullptr; + switch (type) { + case GGML_TYPE_Q4_0 : return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1 : return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0 : return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1 : return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q6_0 : return vec_dot_q6_0_q8_1; + case GGML_TYPE_Q8_0 : return vec_dot_q8_0_q8_1; + case GGML_TYPE_Q2_K : return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K : return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K : return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K : return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K : return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS : return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S : return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1; + default : return nullptr; + } } static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - 1; + switch (type) { + case GGML_TYPE_Q4_0 : return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1 : return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0 : return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1 : return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q6_0 : return VDR_Q6_0_Q8_1_MMVQ; + case GGML_TYPE_Q8_0 : return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_Q2_K : return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K : return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K : return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K : return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K : return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS : return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS : return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S : return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ; + default : return 1; + } } -template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -// tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void mul_mat_vec_q( +template +static __device__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -66,10 +73,8 @@ static __global__ void mul_mat_vec_q( constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) - constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; #else - constexpr int nwarps = ncols_y <= 4 ? 4 : 2; constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) @@ -132,71 +137,106 @@ static __global__ void mul_mat_vec_q( } } -template -static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { + int i2 = blockIdx.y; + char * cdst = (char *)dst + i2*nb2; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int rows_per_cuda_block = 1; +#else + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + const int row0 = rows_per_cuda_block*blockIdx.x; + if (threadIdx.y == 0) { + dst = (float *)cdst; + for (int j = 0; j < ncols_y; ++j) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = 0; + } + } + } + return; + } + const char * cx = (const char *)vx + i02*nb02; + const char * cy = (const char *)vy + i2*nb12; + mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +template +static void mul_mat_vec_q_cuda_T( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); int id = ggml_cuda_get_device(); - int64_t nwarps = 1; - int64_t rows_per_cuda_block = 1; + int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ? + ncols_y < 4 ? 1 : 2 : 1; - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - switch(ncols_y) { - case 1: - nwarps = 4; - rows_per_cuda_block = 1; - break; - case 2: - case 3: - case 4: - nwarps = 4; - rows_per_cuda_block = 2; - break; - case 5: - case 6: - case 7: - case 8: - nwarps = 2; - rows_per_cuda_block = 2; - break; - default: - GGML_ABORT("fatal error"); - break; - } - } + //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + // switch(ncols_y) { + // case 1: + // nwarps = 4; + // rows_per_cuda_block = 1; + // break; + // case 2: + // case 3: + // case 4: + // nwarps = 4; + // rows_per_cuda_block = 2; + // break; + // case 5: + // case 6: + // case 7: + // case 8: + // nwarps = 2; + // rows_per_cuda_block = 2; + // break; + // default: + // GGML_ABORT("fatal error"); + // break; + // } + //} const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, 1, 1); + const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -204,144 +244,328 @@ static void mul_mat_vec_q_cuda( } } -static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { +template +static void mul_mat_vec_q_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { + int nwarps = 1; + int id = ggml_cuda_get_device(); + if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + nwarps = ncols_y <= 4 ? 4 : 2; + } + switch (nwarps) { + case 1: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case 2: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + default: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + } +} - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +static void mul_mat_vec_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +} + +static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, + const int64_t ne00, const int64_t ne0, const int64_t ne2, + const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0, + const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data, + const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t row_diff = row_high - row_low; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + switch (type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q6_0: + mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ1_BN: + mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_BN: + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ3_K: + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_K: + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_KS: + mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ4_KSS: + mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ2_KS: + mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ5_K: + mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ6_K: + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } + +} + +void ggml_cuda_op_mul_mat_vec_q_3D( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src0->ne[2] == src1->ne[2] && src0->ne[2] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; + + const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); + + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, dst->ne[2], + src0->nb[2], src1_row_size, dst->nb[2], 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); + + GGML_UNUSED(src1_ddf_i); } void ggml_cuda_op_mul_mat_vec_q( @@ -351,120 +575,44 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; - int id = ggml_cuda_get_device(); + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, 1, 0, 0, 0, 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); + + GGML_UNUSED(src1_ddf_i); +} + +void ggml_cuda_op_mul_mat_vec_q_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1); + GGML_ASSERT(ids->ne[0] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; + + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, dst->ne[2], + src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], + src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q6_0: - mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_BN: - mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_BN: - mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_K: - mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_K: - mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_KS: - mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_KSS: - mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_KS: - mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ5_K: - mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ6_K: - mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } - - GGML_UNUSED(src1); - GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); } bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index e8ec6850..d17765f1 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,11 +1,27 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -void ggml_cuda_op_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, +void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_mmvq_type_supported(ggml_type src0_type); +void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); + +void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 65c7e5f1..953eb9d9 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "quantize.cuh" #include @@ -37,6 +44,42 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded, const uint64_t stride) { + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (ix0 >= kx0_padded) { + return; + } + + const int64_t ix1 = blockIdx.y; + + const int64_t i_padded = ix1*kx0_padded + ix0; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int64_t ib = i_padded / QK8_1; // block index + const int64_t iqs = i_padded % QK8_1; // quant index + + const float xi = ix0 < kx ? x[ix1*stride + ix0] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { @@ -164,3 +207,19 @@ void quantize_mmq_q8_1_cuda( break; } } + +void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) { + GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1); + GGML_ASSERT(src->type == GGML_TYPE_F32); + const int64_t src_padded_col_size = GGML_PAD(src->ne[0], MATRIX_ROW_PADDING); + GGML_ASSERT(src_padded_col_size % QK8_1 == 0); + if (src->ne[2] == 1 || ggml_is_contiguous(src)) { + quantize_row_q8_1_cuda((const float *)src->data, vy, src->ne[0], 1, 1, src_padded_col_size, type, stream); + return; + } + const int64_t block_num_x = (src_padded_col_size + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, src->ne[2]*src->ne[3], 1); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + const uint64_t stride = src->nb[2]/sizeof(float); + quantize_q8_1<<>>((const float *)src->data, vy, src->ne[0], src_padded_col_size, stride); +} diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 03bf322b..0be5bf0e 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" @@ -22,3 +29,6 @@ void quantize_row_q8_1_cuda( void quantize_mmq_q8_1_cuda( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream); + +// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1 +void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 42ca72af..dd25a2ce 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "rope.cuh" struct rope_corr_dims { diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 499025d1..268b9ae5 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, float s_before, float s_after, const int k) { diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh index 2b875bfb..4c345b2e 100644 --- a/ggml/src/ggml-cuda/softcap.cuh +++ b/ggml/src/ggml-cuda/softcap.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_SOFTCAP_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 6f3056e6..c006301f 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "softmax.cuh" diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 49a83dfa..b97e9a92 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu new file mode 100644 index 00000000..80108615 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu new file mode 100644 index 00000000..66161c0a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu new file mode 100644 index 00000000..ee88c72a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu new file mode 100644 index 00000000..d888a5a4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu new file mode 100644 index 00000000..d93a2d08 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu new file mode 100644 index 00000000..617464c9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu new file mode 100644 index 00000000..970d2b68 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu new file mode 100644 index 00000000..65cd377c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu new file mode 100644 index 00000000..f4a8bf34 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu new file mode 100644 index 00000000..de191a8a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu new file mode 100644 index 00000000..e8cb0e1b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu new file mode 100644 index 00000000..a532e962 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu new file mode 100644 index 00000000..bf25181a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu new file mode 100644 index 00000000..378c132e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu new file mode 100644 index 00000000..372641be --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu new file mode 100644 index 00000000..9ff5968b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..7dda0133 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..740ac37d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu new file mode 100644 index 00000000..f257f5d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..1ea24302 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..6be4d042 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu new file mode 100644 index 00000000..a0f03f49 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu index 2d94e65c..334e1deb 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float); DECL_FATTN_WMMA_F16_CASE(112, 16, float); DECL_FATTN_WMMA_F16_CASE(128, 16, float); DECL_FATTN_WMMA_F16_CASE(256, 16, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu index c3d9df3c..1faf3c9b 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu @@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float); DECL_FATTN_WMMA_F16_CASE(96, 32, float); DECL_FATTN_WMMA_F16_CASE(112, 32, float); DECL_FATTN_WMMA_F16_CASE(128, 32, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu index bb680e40..48973618 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half); DECL_FATTN_WMMA_F16_CASE(112, 16, half); DECL_FATTN_WMMA_F16_CASE(128, 16, half); DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu index 073f71b1..ed92963e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half); DECL_FATTN_WMMA_F16_CASE(112, 32, half); DECL_FATTN_WMMA_F16_CASE(128, 32, half); DECL_FATTN_WMMA_F16_CASE(256, 32, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu index d30710c5..4e221003 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu @@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half); DECL_FATTN_WMMA_F16_CASE(96, 8, half); DECL_FATTN_WMMA_F16_CASE(128, 8, half); DECL_FATTN_WMMA_F16_CASE(256, 8, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu new file mode 100644 index 00000000..940c2da8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 8ffddd6d..6312f25c 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "unary.cuh" static __global__ void gelu_f32(const float * x, float * dst, const int k) { @@ -297,6 +304,19 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); } +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * src0_d, const float * src1_d, float * dst_d) { + + cudaStream_t stream = ctx.stream(); + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -304,19 +324,22 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, src1)); - cudaStream_t stream = ctx.stream(); ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data); - switch (op) { - case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - default: GGML_ASSERT(false); - } + //cudaStream_t stream = ctx.stream(); + + //const float * src0_d = (const float *)src0->data; + //const float * src1_d = (const float *)src1->data; + //float * dst_d = (float *)dst->data; + + //switch (op) { + // case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // default: GGML_ASSERT(false); + //} } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 0235a319..9bcd30a8 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_GELU_BLOCK_SIZE 256 @@ -36,5 +43,7 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * x, const float * y, float * z); void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index e9af29b9..cae5e04f 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1131,6 +1131,18 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); } +static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) { + const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; + const int8_t * q0_8 = (const int8_t *) &q0_32; + const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]); + + const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; + const int8_t * q1_8 = (const int8_t *) &q1_32; + const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]); + + return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +} + #define VDR_IQ4_NL_Q8_1_MMVQ 2 #define VDR_IQ4_NL_Q8_1_MMQ 4 diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 0498be1f..501fe5a2 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -225,6 +225,39 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, @@ -276,9 +309,33 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16, @@ -290,15 +347,17 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CONCAT, + GGML_METAL_KERNEL_TYPE_CONCAT_F32, + GGML_METAL_KERNEL_TYPE_CONCAT_F16, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_COUNT }; +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + struct ggml_backend_metal_context { - int n_cb; id device; id queue; @@ -307,6 +366,26 @@ struct ggml_backend_metal_context { struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -373,7 +452,6 @@ static void * ggml_metal_host_malloc(size_t n) { const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; } #endif @@ -533,6 +611,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); ctx->should_capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { @@ -765,6 +851,39 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, mul_mm_f32_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, mul_mm_f16_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, mul_mm_bf16_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, mul_mm_q4_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, mul_mm_q4_1_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, mul_mm_q5_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, mul_mm_q5_1_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, mul_mm_q6_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, mul_mm_q8_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, mul_mm_q2_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, mul_mm_q3_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, mul_mm_q4_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, mul_mm_q5_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, mul_mm_q6_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, mul_mm_iq2_xxs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, mul_mm_iq2_xs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, mul_mm_iq3_xxs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, mul_mm_iq3_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, mul_mm_iq2_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, mul_mm_iq1_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, mul_mm_iq1_m_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, mul_mm_iq4_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, mul_mm_iq5_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, mul_mm_iq6_k_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm); @@ -816,9 +935,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,flash_attn_ext_f16_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,flash_attn_ext_f16_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,flash_attn_ext_q8_0_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,flash_attn_ext_q8_0_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,flash_attn_ext_vec_f16_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,flash_attn_ext_vec_f16_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, flash_attn_ext_vec_q8_0_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, flash_attn_ext_vec_q8_0_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,flash_attn_ext_vec_q8_0_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,flash_attn_ext_vec_q8_0_hk576_hv512, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); @@ -830,7 +973,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F32, concat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F16, concat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } @@ -846,6 +990,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + Block_release(ctx->encode_async); + [ctx->queue release]; [ctx->device release]; @@ -971,17 +1117,24 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != GGML_TYPE_F16) { + if (!ctx->support_simdgroup_mm) { + return false; // TODO: over-restricted for vec-kernels + } + if (op->src[1]->type != op->src[2]->type || + (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_Q8_0)) { return false; } - if (op->src[2]->type != GGML_TYPE_F16) { - return false; + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + return (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) || + (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512); } - if (op->src[0]->ne[0] == 256) { - return false; - } - return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + return (op->src[1]->ne[0] == 64 || op->src[1]->ne[0] == 80 || + op->src[1]->ne[0] == 96 || op->src[1]->ne[0] == 112 || + op->src[1]->ne[0] == 128 || op->src[1]->ne[0] == 256); case GGML_OP_MUL_MAT: + return ctx->support_simdgroup_reduction && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + !(op->src[0]->type >= GGML_TYPE_Q4_0_R8 && op->src[0]->type <= GGML_TYPE_Q8_K_R8); case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); @@ -1027,934 +1180,882 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx } } -static enum ggml_status ggml_metal_graph_compute( +static void ggml_metal_encode_node( struct ggml_backend_metal_context * ctx, - struct ggml_cgraph * gf) { + struct ggml_tensor * node, + id encoder) { - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ABORT("capture failed"); - } + if (ggml_is_empty(dst)) { + return; } - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: return; // noop + default: break; } - const id *command_buffers = command_buffer_builder; + if (!ggml_metal_supports_op(ctx, dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct ggml_tensor * dst = gf->nodes[i]; + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; - if (ggml_is_empty(dst)) { - continue; - } + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } continue; + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + + //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_CONCAT: + { + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + + id pipeline; + if (dst->type == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F32].pipeline; + } + else if (dst->type == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F16].pipeline; + } + else { + GGML_ABORT("CONCAT not implemented for this type"); + } + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + int64_t nb = ne00; // used by the "row" kernels + + id pipeline = nil; + + if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { + float scale; + memcpy(&scale, src1->data, sizeof(float)); + //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && + dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && + dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && + dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && + dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; + default: GGML_ASSERT(false); + } + + int64_t n = ggml_nelements(dst)/4; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_MULTI_ADD: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(ne02 == 1 && ne03 == 1); + GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + int n_expert = dst->op_params[0]; + GGML_ASSERT(n_expert >= 2); + + id pipeline = nil; + int64_t n = ne0*ne1; + if (ne0%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFTCAP: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scales[2]; + memcpy(scales, dst->op_params, sizeof(scales)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; + [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SWIGLU: + { + int64_t n = ggml_nelements(dst); + GGML_ASSERT(ne0 == src0->ne[0]/2); + + id pipeline = nil; + + uint32_t n_per_row = ne0; + uint32_t stride = src0->nb[1]/sizeof(float); + + if (ne0 % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; + n /= 4; + n_per_row /= 4; + stride /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; + [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx, dst)) { - GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; // used by the "row" kernels - - id pipeline = nil; - - if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { - float scale; - memcpy(&scale, src1->data, sizeof(float)); - //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - - int64_t n = ggml_nelements(dst); - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && - dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && - dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && - dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && - dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { - - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; - default: GGML_ASSERT(false); - } - - int64_t n = ggml_nelements(dst)/4; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_MULTI_ADD: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - GGML_ASSERT(ne02 == 1 && ne03 == 1); - GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - int n_expert = dst->op_params[0]; - GGML_ASSERT(n_expert >= 2); - - id pipeline = nil; - int64_t n = ne0*ne1; - if (ne0%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFTCAP: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scales[2]; - memcpy(scales, dst->op_params, sizeof(scales)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; - [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SWIGLU: - { - int64_t n = ggml_nelements(dst); - GGML_ASSERT(ne0 == src0->ne[0]/2); - - id pipeline = nil; - - uint32_t n_per_row = ne0; - uint32_t stride = src0->nb[1]/sizeof(float); - - if (ne0 % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; - n /= 4; - n_per_row /= 4; - stride /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; - [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_FUSED_MUL_UNARY: - { - int64_t n = ggml_nelements(dst); - enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; - id pipeline = nil; - if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; - n /= 4; - } else { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline - : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SOFT_CAP_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - float s_before; - float s_after; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); - memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; - [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + { + GGML_METAL_LOG_WARN("%s: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_FUSED_MUL_UNARY: + { + int64_t n = ggml_nelements(dst); + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + id pipeline = nil; + if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; + n /= 4; + } else { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline + : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SOFT_CAP_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + float s_before; + float s_after; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); + memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; + [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 4; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case GGML_TYPE_F16: ne11_mm_min = 2; break; @@ -1976,11 +2077,11 @@ static enum ggml_status ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + (src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16) && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers @@ -1993,41 +2094,84 @@ static enum ggml_status ggml_metal_graph_compute( id pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; - default: GGML_ABORT("MUL MAT-MAT not implemented"); + if (src1->type == GGML_TYPE_F32) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + } + else if (src1->type == GGML_TYPE_F16) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + } + else { + GGML_ABORT("Unsupported src1 type for MUL-MAT"); } [encoder setComputePipelineState:pipeline]; @@ -2305,9 +2449,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { @@ -2326,8 +2470,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2348,1179 +2492,1403 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } - } break; - case GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q6_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ1_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ2_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KSS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; - } break; - case GGML_TYPE_IQ4_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; - } break; - case GGML_TYPE_IQ5_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; - } break; - case GGML_TYPE_IQ6_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { - const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { - const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_FUSED_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(src1->ne[0] == src0->ne[0]); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_nrows(src1) == 1); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&eps length:sizeof( float) atIndex:5]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & 2; - - id pipeline = nil; - - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); - - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - - size_t offs_src3 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - - float scale; - float max_bias; - float softcap; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); - if (softcap != 0.0f) { - scale /= softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + //const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength/2 - 8192)/4; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + //GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min && + dst_rows <= dst_rows_max) { + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; } - } - if (should_capture) { - [encoder popDebugGroup]; + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q6_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KSS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; + } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; + } break; + case GGML_TYPE_IQ5_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; + } break; + case GGML_TYPE_IQ6_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; + } break; + default: + { + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { + const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_FUSED_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nrows(src1) == 1); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&eps length:sizeof( float) atIndex:5]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = MIN(256, ne00); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == src2->type); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float softcap; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); + if (softcap != 0.0f) { + scale /= softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192 && ne00 != 576)) { + switch (src1->type) { + case GGML_TYPE_F16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q8_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + default: + { + GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type)); + GGML_METAL_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + use_vec_kernel = true; + switch (src1->type) { + case GGML_TYPE_F16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q8_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + default: + { + GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type)); + GGML_METAL_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + + } + + typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; + } ggml_metal_kargs_flash_attn_ext; + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ softcap, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // 2*(2*ncpsg + nqptg)*(nsg) + // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne00*(nsg) + // each simdgroup has a full f16 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_METAL_LOG_ERROR("%s: error: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } + +} + +static enum ggml_status ggml_metal_graph_compute( + struct ggml_backend_metal_context * ctx, + struct ggml_cgraph * gf) { + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; //ctx_dev->mtl_device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + printf("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } } } - [encoder endEncoding]; + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + [command_buffer enqueue]; + ctx->encode_async(n_cb); } - }); - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } + } - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - NSString * error_code = [command_buffer error].localizedDescription; - GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; } - return GGML_STATUS_FAILED; + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + printf("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; } - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; } - - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; } - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - - } return GGML_STATUS_SUCCESS; } @@ -3905,8 +4273,60 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { GGML_ASSERT(ggml_backend_is_metal(backend)); struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + for (int idx = node_start; idx < node_end; ++idx) { + struct ggml_tensor * node = ctx->gf->nodes[idx]; + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(node) encoding:NSUTF8StringEncoding]]; + } + + ggml_metal_encode_node(ctx, node, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }); - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); } void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 89cd412a..d3a2858c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } + device const block_q8_0 * xr = x + ib; + for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + //device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = xr->qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + //sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*xr->d; + xr += nb; } yb += NB_Q8_0 * nw; @@ -2571,262 +2576,358 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } -typedef void (flash_attn_ext_f16_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); +//========================================================================================== +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + if constexpr (is_same_v) { + const half d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (half)qs[i + 16*il] * d; + } + } else { + const float d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = qs[i + 16*il] * d; + } + } +} + +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; // ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short KV = 8, // key/value processed per each simdgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; - const short D4 = D/4; - const short D8 = D/8; - //const short Q8 = Q/8; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + constexpr short DK4 = DK/4; + constexpr short DK8 = DK/8; + constexpr short DK16 = DK/16; + constexpr short DV4 = DV/4; + constexpr short DV8 = DV/8; + constexpr short DV16 = DV/16; - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) - const short T4 = T/4; // shared memory size per query in (half4) + constexpr short NW = N_SIMDWIDTH; + constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) default: 72 - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + const short TS = nsg*SH; // shared memory size per query in (s_t == float) = 288 for nsg = 4 + const short T = DK + 2*TS; // shared memory size per query in (half) = 1152 for nsg = 4 and DK = 576 + // => Q*T is 9216 for nsg = 4 and DK = 576 => 18432 bytes => overflows the 16384 bytes predicted as shmem in ggml-metal.m + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + o8x8_t lo[DV8]; // For DV = 512 we have DV8 = 64 => 4096 entries per thread. Do we even have so much // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { - sq4[j*T4 + i] = (half4) q4[i]; + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 + j < args.ne01) { + sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*T4 + i] = 0.0h; + sq4[j*DK4 + i] = (q4_t) 0.0f; } } } // zero out lo - for (short i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (short i = 0; i < DV8; ++i) { + lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); } // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; + ss[j*TS + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); { - float S[Q] = { [0 ... Q-1] = 0.0h }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + float S[Q] = { [0 ... Q-1] = 0.0f }; + float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; + // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) + const short tx = tiisg%4; + const short ty = tiisg/4; - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // k indices - const short ik2 = iq2/rk2; - const short ik3 = iq3/rk3; - - // v indices - const short iv2 = iq2/rv2; - const short iv3 = iq3/rv3; - - // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); - } - - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); + const bool has_mask = mask != q; float slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; + if (args.max_bias > 0.0f) { + const short h = iq2; - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } + if (has_mask) { + // used to detect blocks full of -INF + float smax = -INFINITY; + + // load the mask in shared memory + #pragma unroll(Q) + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + + const float m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { + k8x8_t mk; + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } + } } - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + // cast qk_t -> s_t + //s8x8_t mqks(1.0f); + //simdgroup_multiply(mqks, mqk, mqks); + //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - const short tx = tiisg%4; - const short ty = tiisg/4; - - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; - - if (softcap != 0.0f) { - ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); - ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); - } - - if (mask != q) { - // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); } } - // used to detect blocks full of -INF - float smax = -INFINITY; - // online softmax { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const short p = tiisg; - + for (ushort j = 0; j < Q; ++j) { const float m = M[j]; - const float s = ss[j*TF + p]; - smax = simd_max(max(smax, s)); + // scale and apply the logitcap / mask + float s = ss[j*TS + tiisg]*args.scale; + + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); + } + + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; + M[j] = simd_max(max(M[j], s)); - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); - S[j] = S[j]*ms[j] + simd_sum(vs); + S[j] = S[j]*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } + ss[j*TS + tiisg] = vs; - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } } } - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - // O = diag(ms)*O { - simdgroup_float8x8 mm; - simdgroup_load(mm, ss + C, TF, 0, false); + s8x8_t mm; + simdgroup_load(mm, ss + 2*C, TS, 0, false); - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -2834,16 +2935,64 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + s8x8_t ms; + simdgroup_load(ms, ss + 8*cc, TS, 0, false); - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + if (is_same::value) { + // we can read directly from global memory + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + v8x8_t mv; + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + } + } else { + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + if (DV16%4 == 0) { + // no need for bound checks + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } else { + if (ii + tx < DV16) { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DV16; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } + } } } } @@ -2852,23 +3001,23 @@ kernel void kernel_flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*TF + 0] = S[j]; - ss[j*TF + 1] = M[j]; + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; } } } // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; + for (ushort sg = 1; sg < nsg; ++sg) { + float S = { 0.0f }; + float M = { -__FLT16_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -2877,11 +3026,11 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TF + 0]; - const float S1 = ss[j*TF + sg*SH + 0]; + const float S0 = ss[j*TS + 0]; + const float S1 = ss[j*TS + sg*SH + 0]; - const float M0 = ss[j*TF + 1]; - const float M1 = ss[j*TF + sg*SH + 1]; + const float M0 = ss[j*TS + 1]; + const float M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); @@ -2891,25 +3040,27 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*TF + 0] = S; - ss[j*TF + 1] = M; + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; - ss[j*TF + C + j ] = ms0; - ss[j*TF + C + j + sg*SH] = ms1; + ss[j*TS + 2*C + j ] = ms0; + ss[j*TS + 2*C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { - simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; + s8x8_t ms0; + s8x8_t ms1; - simdgroup_load(ms0, ss + C, TF, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + simdgroup_load(ms0, ss + 2*C, TS, 0, false); + simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); - for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + o8x8_t t; + + simdgroup_load (t, so + i*8, DV, 0, false); simdgroup_multiply(t, ms1, t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); @@ -2920,8 +3071,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -2929,206 +3080,246 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = ss[j*TF + 0]; + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { + const float S = ss[j*TS + 0]; - for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; +// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as +// template to be able to explore different combinations +// +#define FA_TYPES \ + half, half4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#undef FA_TYPES + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // key type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + constexpr short DK4 = DK/4; + constexpr short DV4 = DV/4; + constexpr short NW = N_SIMDWIDTH; + constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + constexpr short SH = 4*C; // shared memory per simdgroup - const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short T = DK + nsg*SH; // shared memory size per query in (half) - float slope = 1.0f; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 < args.ne01) { + sq4[i] = (q4_t) q4[i]; } else { - sq4[i] = 0.0h; + sq4[i] = (q4_t) 0.0f; } } // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; } // zero out shared memory SH for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; + ss4[i] = (s4_t) 0.0f; } threadgroup_barrier(mem_flags::mem_threadgroup); { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; + float S = 0.0f; + float M = -__FLT16_MAX__/2; - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; + // thread indices inside the simdgroup + const short tx = tiisg%NL; + const short ty = tiisg/NL; - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; - - // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; - - // load the queries from shared memory into local memory - float4 mq[D4]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[i] = (float4)sq4[i]; - } + const bool has_mask = mask != q; // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + // Q*K^T { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { + const short i = ii + tx; - float4x4 mk; - mk[0] = (float4)pk4[i + 0*(nb11/8)]; - mk[1] = (float4)pk4[i + 1*(nb11/8)]; - mk[2] = (float4)pk4[i + 2*(nb11/8)]; - mk[3] = (float4)pk4[i + 3*(nb11/8)]; + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += (float4) (mq[i] * mk); + // note: this is less precise than the version below + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); + + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); } - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } // mqk = mqk*scale + mask*slope - if (tiisg == 0) { - mqk *= scale; - if (softcap != 0.0f) { - mqk = softcap*precise::tanh(mqk); + if (tx == 0) { + mqk *= args.scale; + + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } - mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; - ss4[cc] = mqk; + mqk += sm[NE*cc + ty]*slope; + + ss[NE*cc + ty] = mqk; } - } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // online softmax { - const short p = tiisg; - const float m = M; - const float s = ss[p]; + const float s = ss[tiisg]; M = simd_max(max(M, s)); @@ -3138,47 +3329,96 @@ kernel void kernel_flash_attn_ext_vec_f16( S = S*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; + ss[tiisg] = vs; // O = diag(ms)*O -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + lo[ii/NL] *= ms; } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // O = O + (Q*K^T)*V { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + const s4_t ms(ss[NE*cc + ty]); - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + const short i = ii + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii/NL] += o4_t(float4(mv)*float4(ms)); } } } - } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) if (tiisg == 0) { - ss[0] = S; - ss[1] = M; + ss[0] = (s_t) S; + ss[1] = (s_t) M; } } + // simdgroup reduce (NE = 4) + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } + + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } + + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } + + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = lo[ii/NW]; + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3186,11 +3426,11 @@ kernel void kernel_flash_attn_ext_vec_f16( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*(SH/2) + 0]; - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*(SH/2) + 1]; const float M = max(M0, M1); @@ -3205,9 +3445,8 @@ kernel void kernel_flash_attn_ext_vec_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; } } @@ -3220,15 +3459,50 @@ kernel void kernel_flash_attn_ext_vec_f16( if (sgitg == 0) { const float S = ss[0]; - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; } } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem +// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max +// +#define FA_TYPES \ + half4, \ + half4, \ + half4, \ + float, \ + float, float4, \ + half4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES template kernel void kernel_cpy( @@ -3804,7 +4078,8 @@ kernel void kernel_cpy_f32_iq4_nl( } } -kernel void kernel_concat( +template +static inline void concat_impl( device const char * src0, device const char * src1, device char * dst, @@ -3844,21 +4119,93 @@ kernel void kernel_concat( int64_t o[4] = {0, 0, 0, 0}; o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - device const float * x; + device const src_t * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (device const src_t *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const src_t *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device src_t * y = (device src_t *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } +kernel void kernel_concat_f32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + concat_impl(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg); +} + +kernel void kernel_concat_f16( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + concat_impl(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg); +} + + + void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, @@ -7087,21 +7434,6 @@ kernel void kernel_mul_mv_iq6_k_f32( //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} template void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) { @@ -7215,16 +7547,6 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg } } -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - template void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -7956,7 +8278,7 @@ struct DequantizerRSBN { }; // each block_q contains 16*nl weights -template +template kernel void kernel_mul_mm(device const uchar * src0, device const uchar * src1, device float * dst, @@ -8006,8 +8328,8 @@ kernel void kernel_mul_mm(device const uchar * src0, uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0); - device const float * y = (device const float *)(src1 + device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0); + device const src1_t * y = (device const src1_t *)(src1 + nb12 * im + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); @@ -8027,7 +8349,12 @@ kernel void kernel_mul_mm(device const uchar * src0, + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + if (is_same_v) { + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + } else { + half2x4 h = *((device half2x4 *)y); + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = (float2x4)h; + } deq.next(); y += BLOCK_SIZE_K; @@ -8246,39 +8573,6 @@ kernel void kernel_mul_mm_id( uint ntg = ntg3.x * ntg3.y * ntg3.z; uint n = nei0*nei1; - //uint npt = (n + ntg - 1) / ntg; - //uint first = tiitg * npt; - //uint last = first + npt <= n ? first + npt : n; - - //uint nhave = 0; - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) ++nhave; - //} - //threadgroup uint * nums = (threadgroup uint *)shared_memory; - //nums[tiitg] = nhave; - //threadgroup_barrier(mem_flags::mem_threadgroup); - - //uint nprev = 0; - //for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - //int64_t _ne1 = nprev; - //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; - - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) rowids[nprev++] = ushort2(ii0, ii1); - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - // - // The following is slightly faster than the commented out version above - // uint nhave = 0; for (uint i = tiitg; i < n; i += ntg) { uint ii0 = i % nei0; @@ -8290,10 +8584,24 @@ kernel void kernel_mul_mm_id( nums[tiitg] = nhave; threadgroup_barrier(mem_flags::mem_threadgroup); - uint nprev = 0; - for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - int64_t _ne1 = nprev; - for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; + uint stride = 1; + while (stride <= ntg/2) { + uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1; + if (index < ntg) nums[index] += nums[index-stride]; + stride <<= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + stride = ntg/2; + while (stride > 0) { + uint index = (tiitg+1)*stride*2 - 1; + if (index+stride < ntg) nums[index+stride] += nums[index]; + stride >>= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + uint _ne1 = nums[ntg-1]; + if (!_ne1) return; + + uint nprev = tiitg > 0 ? nums[tiitg-1] : 0; threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); for (uint i = tiitg; i < n; i += ntg) { @@ -8304,47 +8612,37 @@ kernel void kernel_mul_mm_id( } threadgroup_barrier(mem_flags::mem_threadgroup); - // This is the original version that is ridiculously slow. - //// row indices - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + uint nstep = (_ne1 + BLOCK_SIZE_N - 1)/BLOCK_SIZE_N; - //// TODO: parallelize this loop - //int64_t _ne1 = 0; - //for (ushort ii1 = 0; ii1 < nei1; ii1++) { - // for (ushort ii0 = 0; ii0 < nei0; ii0++) { - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) { - // //if (tiitg == 0) { - // rowids[_ne1] = ushort2(ii0, ii1); - // //} - // _ne1++; - // } - // } - //} + for (uint istep = 0; istep < nstep; ++istep) { - //threadgroup_barrier(mem_flags::mem_threadgroup); + uint first = BLOCK_SIZE_N*istep; + uint last = first + BLOCK_SIZE_N < _ne1 ? first + BLOCK_SIZE_N : _ne1; + int64_t this_ne1 = last - first; + threadgroup ushort2 * this_rowids = rowids + istep*BLOCK_SIZE_N; - kernel_mul_mm_id_impl( - src0, - src1, - rowids, - dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, - tgpig, - tiitg, - sgitg); + kernel_mul_mm_id_impl( + src0, + src1, + this_rowids, + dst, + ne00, + ne02, + nb01, + nb02, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + this_ne1, + ne0*ne1, + shared_memory, + tgpig, + tiitg, + sgitg); + } } #define QK_NL 16 @@ -8399,41 +8697,76 @@ template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get template using DD = DefaultDequantizer; -typedef decltype(kernel_mul_mm>) mat_mm_t; +typedef decltype(kernel_mul_mm, float>) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm, float>; + +template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q6_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq5_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq6_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_bn_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_bn_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_ks_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm, half>; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm>; // // indirect matrix-matrix multiplication diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d0154b72..44dce586 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1,6 +1,6 @@ // -// Copyright (C) 2023-2024 The ggml authors // Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2023-2024 The ggml authors // MIT license // SPDX-License-Identifier: MIT // @@ -11,6 +11,7 @@ #include "ggml-quants.h" #include "ggml-impl.h" #if GGML_USE_IQK_MULMAT +#include "iqk/iqk_config.h" #include "iqk/iqk_mul_mat.h" #include "iqk/iqk_quantize.h" #endif @@ -1767,10 +1768,8 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * float scale = suml2 ? sumlx/suml2 : 0.0f; if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; float best = scale * sumlx; + float best_sumlx = sumlx, best_suml2 = suml2; for (int is = -9; is <= 9; ++is) { - if (is == 0) { - continue; - } iscale = -(nmax + 0.1f*is) / max; sumlx = suml2 = 0; for (int i = 0; i < n; ++i) { @@ -1786,7 +1785,66 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); } scale = sumlx/suml2; best = scale*sumlx; + best_sumlx = sumlx; best_suml2 = suml2; } + iscale = (nmax-1 + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + best_sumlx = sumlx; best_suml2 = suml2; + } + } + + sumlx = best_sumlx; suml2 = best_suml2; + for (int iter = 0; iter < n*(2*nmax-1); ++iter) { + float abs_gmax = 0, gmax = 0; + int best_j = -1; + for (int j = 0; j < n; ++j) { + float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j])); + int l = L[j] - nmax; + float g = scale * w * (x[j] - scale*l); + if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) { + float ag = fabsf(g); + if (ag > abs_gmax) { + abs_gmax = ag; gmax = g; best_j = j; + } + } + } + if (best_j < 0) break; + + float new_sumlx = sumlx, new_suml2 = suml2; + float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j])); + int l = L[best_j] - nmax; + if (gmax > 0) { + new_sumlx += w*x[best_j]; + new_suml2 += w*(2*l + 1); + l += 1; + } else { + new_sumlx -= w*x[best_j]; + new_suml2 -= w*(2*l - 1); + l -= 1; + } + if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) { + sumlx = new_sumlx; suml2 = new_suml2; + scale = sumlx/suml2; best = scale*sumlx; + L[best_j] = l + nmax; + GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1); + } + else { + break; + } + } return scale; } @@ -2141,8 +2199,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f float rmin, float rdelta, int nstep, bool use_mad) { float min = x[0]; float max = x[0]; - float sum_w = weights ? weights[0] : x[0]*x[0]; - float sum_x = sum_w * x[0]; + double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]); + double sum_x = sum_w * (double)x[0]; + double sum_x2 = sum_w * (double)x[0] * (double)x[0]; #ifdef HAVE_BUGGY_APPLE_LINKER // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 for (volatile int i = 1; i < n; ++i) { @@ -2152,8 +2211,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f if (x[i] < min) min = x[i]; if (x[i] > max) max = x[i]; float w = weights ? weights[i] : x[i]*x[i]; - sum_w += w; - sum_x += w * x[i]; + sum_w += (double)w; + sum_x += (double)w * (double)x[i]; + sum_x2 += (double)w * (double)x[i] * (double)x[i]; } if (min > 0) { min = 0; @@ -2165,13 +2225,13 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } float iscale = nmax/(max - min); float scale = 1/iscale; - float best_mad = 0; + double best_mad = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); L[i] = MAX(0, MIN(nmax, l)); - float diff = scale * L[i] + min - x[i]; - diff = use_mad ? fabsf(diff) : diff*diff; - float w = weights ? weights[i] : x[i]*x[i]; + double diff = (double)scale * L[i] + (double)min - (double)x[i]; + diff = use_mad ? fabs(diff) : diff*diff; + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); best_mad += w * diff; } if (nstep < 1) { @@ -2180,30 +2240,35 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } for (int is = 0; is <= nstep; ++is) { iscale = (rmin + rdelta*is + nmax)/(max - min); - float sum_l = 0, sum_l2 = 0, sum_xl = 0; + double sum_l = 0, sum_l2 = 0, sum_xl = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); l = MAX(0, MIN(nmax, l)); Laux[i] = l; float w = weights ? weights[i] : x[i]*x[i]; - sum_l += w*l; - sum_l2 += w*l*l; - sum_xl += w*l*x[i]; + sum_l += (double)w*l; + sum_l2 += (double)w*l*l; + sum_xl += (double)w*l*(double)x[i]; } - float D = sum_w * sum_l2 - sum_l * sum_l; + double D = sum_w * sum_l2 - sum_l * sum_l; if (D > 0) { - float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; - float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; if (this_min > 0) { this_min = 0; this_scale = sum_xl / sum_l2; } - float mad = 0; - for (int i = 0; i < n; ++i) { - float diff = this_scale * Laux[i] + this_min - x[i]; - diff = use_mad ? fabsf(diff) : diff*diff; - float w = weights ? weights[i] : x[i]*x[i]; - mad += w * diff; + double mad = 0; + if (use_mad) { + for (int i = 0; i < n; ++i) { + double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i]; + diff = fabs(diff); + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); + mad += w * diff; + } + } else { + mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l + + this_scale*this_scale*sum_l2 + this_min*this_min*sum_w; } if (mad < best_mad) { for (int i = 0; i < n; ++i) { @@ -2215,6 +2280,57 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } } } + if (use_mad) { + *the_min = -min; + return scale; + } + + double sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = L[i]; + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*(double)x[i]; + } + double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l + - (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w; + int last_j = -1, last_dir = 0; + for (int itry = 0; itry < nmax*n; ++itry) { + float gmax = 0; + int best_j = -1, dir = 0; + for (int j = 0; j < n; ++j) { + float g = x[j] - scale*L[j] - min; + if (g > 0 && L[j] < nmax && g > gmax) { + gmax = g; best_j = j; dir = 1; + } + else if (g < 0 && L[j] > 0 && -g > gmax) { + gmax = -g; best_j = j; dir = -1; + } + } + if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break; + double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]); + sum_l += w*dir; + sum_l2 += w*(2*L[best_j]*dir + 1); + sum_xl += w*(double)x[best_j]*dir; + double D = (double)sum_w * sum_l2 - sum_l * sum_l; + if (D <= 0) break; + double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D; + double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + if (this_scale < 0) break; + double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l + - this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w; + if (score <= best) break; + best = score; + scale = this_scale; + min = this_min; + L[best_j] += dir; + last_j = best_j; last_dir = dir; + } *the_min = -min; return scale; } @@ -2296,7 +2412,6 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri GGML_ASSERT(quant_weights); assert(k % QK_K == 0); const int nb = k / QK_K; - const bool requantize = true; uint8_t L[QK_K]; uint8_t Laux[16]; @@ -2310,7 +2425,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri memset(sw, 0, QK_K/16*sizeof(float)); float sumx2 = 0; for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; - float sigma2 = sumx2/QK_K; + float sigma2 = 0.75f*sumx2/QK_K; for (int j = 0; j < QK_K/16; ++j) { const float * restrict qw = quant_weights + QK_K * i + 16*j; for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); @@ -2318,31 +2433,25 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } - float dm, mm; - dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); - mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); y[i].d = GGML_FP32_TO_FP16(dm); y[i].dmin = GGML_FP32_TO_FP16(mm); - dm = GGML_FP16_TO_FP32(y[i].d); - mm = GGML_FP16_TO_FP32(y[i].dmin); for (int j = 0; j < QK_K/16; ++j) { - y[i].scales[j] = Ls[j] | (Lm[j] << 4); - } - - if (requantize) { - for (int j = 0; j < QK_K/16; ++j) { - const float d = dm * (y[i].scales[j] & 0xF); - if (!d) continue; - const float m = mm * (y[i].scales[j] >> 4); - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int((x[16*j + ii] + m)/d); - l = MAX(0, MIN(3, l)); - L[16*j + ii] = l; - } + float d = dm*Ls[j]; + float m = mm*Lm[j]; + float id = d ? 1/d : 0.f; + for (int l = 0; l < QK_K/16; ++l) { + int q = nearest_int((x[16*j + l] + m)*id); + q = MAX(0, MIN(3, q)); + L[16*j + l] = q; } } + for (int j = 0; j < QK_K/16; ++j) { + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { @@ -3253,8 +3362,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri const int64_t nb = n_per_row/QK4_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_0 * ib; - const float * qw = quant_weights + QK4_0 * ib; - for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + if (quant_weights) { + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j]; + } float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); y[ib].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 16; ++j) { @@ -5449,7 +5562,12 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { +#ifdef HAVE_FANCY_SIMD + enum ggml_type dot_type = GGML_TYPE_Q8_1_X4; +#else + enum ggml_type dot_type = GGML_TYPE_Q8_0_X4; +#endif + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) { return; } #endif @@ -12967,6 +13085,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict } } +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 1 : 0; +} + static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); @@ -12996,6 +13120,9 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v bool is_on_grid_aux[2]; uint8_t block_signs[2]; uint16_t q2[2*(QK_K/16)]; + uint16_t index[2], aux_index[2]; + float sumx[17], sumw[17], pairs[32]; + int * int_pairs = (int *)(pairs + 1); for (int ibl = 0; ibl < nbl; ++ibl) { @@ -13048,11 +13175,35 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v memset(L, 0, 16); continue; } - float best = 0; - float scale = max/(2*kMaxQ-1); + for (int j = 0; j < 16; ++j) { + pairs[2*j] = xval[j]; + int_pairs[2*j] = j; + } + qsort(pairs, 16, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < 16; ++j) { + int i = int_pairs[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xval[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best = 0, scale = 0; + for (int i1 = 0; i1 <= 16; ++i1) { + for (int i2 = i1; i2 <= 16; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*1 + (sumx[i2] - sumx[i1])*3 + (sumx[16] - sumx[i2])*5; + float sumq2 = (sumw[i1] - sumw[0])*1 + (sumw[i2] - sumw[i1])*9 + (sumw[16] - sumw[i2])*25; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + } + } + } + best = 0; + float eff_max = scale*(2*kMaxQ - 1); is_on_grid[0] = is_on_grid[1] = true; - for (int is = -9; is <= 9; ++is) { - float id = (2*kMaxQ-1+is*0.1f)/max; + index[0] = index[1] = 0; + for (int is = -7; is <= 7; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/eff_max; float this_scale = 1/id; for (int k = 0; k < 2; ++k) { for (int i = 0; i < 8; ++i) { @@ -13068,6 +13219,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); } + aux_index[k] = grid_index; } float sumqx = 0, sumq2 = 0; for (int i = 0; i < 16; ++i) { @@ -13080,35 +13232,45 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v scale = sumqx/sumq2; best = scale*sumqx; for (int i = 0; i < 16; ++i) L[i] = Laux[i]; for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + for (int k = 0; k < 2; ++k) index[k] = aux_index[k]; } } - int n_not_ongrid = 0; - for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; - if (n_not_ongrid > 0 && scale > 0) { - float id = 1/scale; - for (int k = 0; k < 2; ++k) { - if (is_on_grid[k]) continue; - uint16_t u = 0; - for (int i = 0; i < 8; ++i) { - int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); - l = MAX(0, MIN(kMaxQ-1, l)); - u |= (l << 2*i); - L[8*k + i] = l; + if (scale) { + for (int iter = 0; iter < 3; ++iter) { + float id = 1/scale; + bool changed = false; + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + Laux[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, Laux + 8*k); + } + aux_index[k] = grid_index; + if (grid_index != index[k]) changed = true; } - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + if (!changed) break; + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; + best = scale*sumqx; + memcpy(L, Laux, 16); + for (int k = 0; k < 2; ++k) index[k] = aux_index[k]; + } + else break; } - float sumqx = 0, sumq2 = 0; - for (int i = 0; i < 16; ++i) { - float w = weight[i]; - float q = 2*L[i] + 1; - sumqx += w*xval[i]*q; - sumq2 += w*q*q; - } - if (sumq2 > 0) scale = sumqx/sumq2; } if (scale < 0) { scale = -scale; @@ -13139,13 +13301,34 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v float d = max_scale/31; y[ibl].d = GGML_FP32_TO_FP16(d); float id = 1/d; + float sumqx = 0, sumq2 = 0; for (int ib = 0; ib < QK_K/16; ++ib) { int l = nearest_int(0.5f*(id*scales[ib]-1)); l = MAX(0, MIN(15, l)); if (ib%2 == 0) y[ibl].scales[ib/2] = l; else y[ibl].scales[ib/2] |= (l << 4); + l = 2*l + 1; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int k = 0; k < 2; ++k) { + int grid_index = q2[2*ib+k] & 511; + const int8_t * grid = (const int8_t *)(iq2xs_grid + grid_index); + const uint8_t signs = ksigns_iq2xs[q2[2*ib+k] >> 9]; + for (int j = 0; j < 8; ++j) { + float w = weight[8*k+j]; + float q = 0.125f*l*grid[j]*(signs & kmask_iq2xs[j] ? -1.f : 1.f); + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } } memcpy(y[ibl].qs, q2, QK_K/4); + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(1.05f*sumqx/sumq2); } } @@ -13985,12 +14168,6 @@ static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const return grid_index; } -static int iq1_sort_helper(const void * left, const void * right) { - const float * l = left; - const float * r = right; - return *l < *r ? -1 : *l > *r ? 1 : 0; -} - void iq1s_process_1block(int block_size, const float * xb, const float * weight, int8_t * L, float * the_scale, uint16_t * the_index, int * the_shift, float * pairs, float * sumx, float * sumw) { float max = fabsf(xb[0]); @@ -14214,6 +14391,8 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; float sumqx[4], sumq2[4]; + float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1]; + float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1]; const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); @@ -14237,7 +14416,22 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = -FLT_MIN, scale = 0.f; + sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumw1[j+1] = sumw1[j] + weight[i]; + sumx1[j+1] = sumx1[j] + weight[i]*xb[i]; + sumw2[j+1] = sumw2[j]; + sumx2[j+1] = sumx2[j]; + } else { + sumw2[j+1] = sumw2[j] + weight[i]; + sumx2[j+1] = sumx2[j] + weight[i]*xb[i]; + sumw1[j+1] = sumw1[j]; + sumx1[j+1] = sumx1[j]; + } + } + float best_score = 0, scale = 0.f; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - @@ -14245,74 +14439,22 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo // 3: -, - for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { - memset(sumqx, 0, 4*sizeof(float)); - memset(sumq2, 0, 4*sizeof(float)); - for (int j = 0; j < i1; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } else { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } - } - for (int j = i1; j < i2; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } else { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } - } - for (int j = i2; j < block_size; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } else { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } - } + sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] + + (sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2]; + sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] + + (sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2]; + sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] + + (sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2]; + sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] + + (sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2]; + sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] + + (sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2]; + sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] + + (sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2]; + sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] + + (sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2]; + sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] + + (sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2]; for (int k = 0; k < 4; ++k) { if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; @@ -14347,19 +14489,34 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo the_index[k] = grid_index; } if (!all_on_grid) { - float sumqx_f = 0, sumq2_f = 0; - for (int k = 0; k < block_size/8; ++k) { - if (k == 0) xx = best_k < 2 ? x_p : x_m; - else xx = best_k%2 == 0 ? x_p : x_m; - const int8_t * pg = (const int8_t *)(kgrid_q2xs + the_index[k]); - for (int j = 0; j < 8; ++j) { - float w = weight[8*k + j]; - float q = xx[(pg[j] - 1)/2]; - sumqx_f += w*q*xb[8*k+j]; - sumq2_f += w*q*q; + sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0; + sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float qp = x_p[L[j]]; + float qm = x_m[L[j]]; + sumqx[0] += w*xb[j]*qp; + sumq2[0] += w*qp*qp; + sumqx[3] += w*xb[j]*qm; + sumq2[3] += w*qm*qm; + if (j < 8) { + sumqx[1] += w*xb[j]*qp; + sumq2[1] += w*qp*qp; + sumqx[2] += w*xb[j]*qm; + sumq2[2] += w*qm*qm; + } else { + sumqx[2] += w*xb[j]*qp; + sumq2[2] += w*qp*qp; + sumqx[1] += w*xb[j]*qm; + sumq2[1] += w*qm*qm; + } + } + best_score = 0; + for (int k = 0; k < 4; ++k) { + if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k; } } - if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; } *the_scale = scale; *the_shift = best_k; @@ -14393,6 +14550,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; + float all_sigma2[QK_K/32]; iq1m_scale_t s; const float * xx; @@ -14405,11 +14563,18 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy float max_scale = 0; const float * xbl = x + QK_K*ibl; - float sumx2 = 0; - for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; - float sigma2 = 2*sumx2/QK_K; + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + float sumx2 = 0; + for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i]; + all_sigma2[ib] = 1.5f*sumx2/32; + } + //float sumx2 = 0; + //for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + //float sigma2 = 1.5f*sumx2/QK_K; for (int ib = 0; ib < QK_K/block_size; ++ib) { + float sigma2 = all_sigma2[ib/2]; const float * xb = xbl + block_size*ib; if (quant_weights) { const float * qw = quant_weights + QK_K*ibl + block_size*ib; @@ -14418,12 +14583,21 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; } float max = fabsf(xb[0]); - for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + float sumwx = 0; + for (int i = 1; i < block_size; ++i) { + float ax = fabsf(xb[i]); + max = MAX(max, ax); + sumwx += weight[i]*ax; + } if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; memset(L, 1, block_size); continue; } + if (sumwx == 0) { + // weight is zero everywhere where xb is not zero => ignore + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } int best_k = -1; iq1m_process_1block(xb, weight, L, &scales[ib], index, &best_k, pairs); @@ -14444,6 +14618,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy float id = 1/d; float sumqx_f = 0, sumq2_f = 0; for (int ib = 0; ib < QK_K/block_size; ++ib) { + float sigma2 = all_sigma2[ib/2]; int l = nearest_int(0.5f*(id*scales[ib+0]-1)); l = MAX(0, MIN(7, l)); sc[ib/4] |= (l << 3*(ib%4)); @@ -14468,7 +14643,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } } if (sumq2_f > 0) d = sumqx_f/sumq2_f; - s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. + s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed. sc[0] |= ((s.u16 & 0x000f) << 12); sc[1] |= ((s.u16 & 0x00f0) << 8); sc[2] |= ((s.u16 & 0x0f00) << 4); @@ -14575,6 +14750,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } d = sumqx/sumq2; float best = d*sumqx; + float best_sumqx = sumqx, best_sumq2 = sumq2; for (int itry = -ntry; itry <= ntry; ++itry) { id = (itry + values[0])/max; sumqx = sumq2 = 0; @@ -14588,8 +14764,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d * sumqx; + best_sumqx = sumqx; best_sumq2 = sumq2; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + Lb[j] = best_index_iq4nl(values, al); + } + } + id = (itry + values[15])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + best_sumqx = sumqx; best_sumq2 = sumq2; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + Lb[j] = best_index_iq4nl(values, al); + } } } + sumqx = best_sumqx; sumq2 = best_sumq2; + for (int iter = 0; iter < 32*block_size; ++iter) { + float min_step = INFINITY; + int best_j = -1; int dir = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float g = d * w * (xb[j] - d*values[Lb[j]]); + if (g > 0 && Lb[j] < 15) { + float step = (values[Lb[j]+1] - values[Lb[j]])/g; + if (step < min_step) { + min_step = step; best_j = j; dir = 1; + } + } + else if (g < 0 && Lb[j] > 0) { + float step = (values[Lb[j]-1] - values[Lb[j]])/g; + if (step < min_step) { + min_step = step; best_j = j; dir = -1; + } + } + } + if (best_j < 0) break; + + float new_sumqx = sumqx, new_sumq2 = sumq2; + float w = weight[best_j]; + new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]); + new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]); + if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) { + sumqx = new_sumqx; sumq2 = new_sumq2; + d = sumqx/sumq2; best = d*sumqx; + Lb[best_j] += dir; + } + else { + break; + } + } + scales[ib] = d; float abs_d = fabsf(d); if (abs_d > amax_scale) { @@ -15217,8 +15452,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ3_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_IQ5_K_R4: break; - case GGML_TYPE_IQ4_KS_R4: break; - case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_IQ4_KS_R4:break; + case GGML_TYPE_Q8_KV_R8: break; + case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_Q8_KV: break; case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4bc377d5..8e58c93a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14,6 +14,7 @@ #include "iqk/iqk_quantize.h" #if GGML_USE_IQK_MULMAT #include "iqk/iqk_mul_mat.h" +#include "iqk/iqk_config.h" #endif #if defined(_MSC_VER) || defined(__MINGW32__) @@ -716,7 +717,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q4_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -740,7 +741,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, #if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_1, #endif @@ -788,7 +793,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q5_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -808,7 +813,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, #if GGML_USE_IQK_MULMAT +#ifdef __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_1, #endif @@ -826,7 +835,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q6_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -847,12 +856,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, #if GGML_USE_IQK_MULMAT -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) +#ifdef HAVE_FANCY_SIMD // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2 // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range // (and it gets satured if it does), leading to wrong results. - // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above. - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -897,6 +905,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q8_2_X4] = { + .type_name = "q8_2_x4", + .blck_size = QK8_2, + .type_size = sizeof(block_q8_2), + .is_quantized = true, + .from_float = quantize_row_q8_2_x4, + .from_float_ref = quantize_row_q8_2_x4, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1271,8 +1289,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, #if GGML_USE_IQK_MULMAT -#if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, +#if defined HAVE_FANCY_SIMD + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1362,6 +1380,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K128, .row_meta_size = 0, }, + [GGML_TYPE_Q8_KV] = { + .type_name = "q8_KV", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV, + .from_float = quantize_row_q8_KV, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref, + .vec_dot = vec_dot_q8_KV_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 8, + }, + [GGML_TYPE_Q8_KV_R8] = { + .type_name = "q8_KV_r8", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8, + .from_float = quantize_row_q8_KV_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref, + .vec_dot = vec_dot_q8_KV_r8_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 4, + }, [GGML_TYPE_Q8_K16] = { .type_name = "q8_K16", .blck_size = 64, @@ -1643,7 +1685,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq4_nl_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1677,7 +1719,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q4_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1698,7 +1740,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q8_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1719,7 +1761,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q5_0_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1740,7 +1782,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q6_0_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1771,6 +1813,15 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { return type_traits[type]; } +static inline int ggml_packed_rows(enum ggml_type type) { + return type == GGML_TYPE_BF16_R16 ? 16 + : type == GGML_TYPE_Q8_K_R8 || type == GGML_TYPE_Q8_KV_R8 || + type == GGML_TYPE_Q8_0_R8 || type == GGML_TYPE_Q4_0_R8 || + type == GGML_TYPE_IQ4_XS_R8 ? 8 + : type >= GGML_TYPE_Q4_0_R8 && type <= GGML_TYPE_Q8_K_R8 ? 4 + : 1; +} + // // simd mappings // @@ -3860,6 +3911,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "MUL_MAT", "MUL_MAT_ID", "OUT_PROD", + "MOE_FUSED_UP_GATE", "SCALE", "SET", @@ -3889,6 +3941,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "ARGSORT_THRESH", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -3919,7 +3972,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3953,6 +4006,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "X*Y", "X[i]*Y", "X*Y", + "X*Y1&X*Y2", "x*v", "y-\\>view(x)", @@ -3982,6 +4036,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "argsort_thresh(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -4012,7 +4067,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4280,6 +4335,9 @@ GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { } GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) return 0; + } size_t nbytes; size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { @@ -4412,6 +4470,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; @@ -4423,6 +4482,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break; case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break; + case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; @@ -6784,6 +6844,51 @@ struct ggml_tensor * ggml_mul_mat_id( return result; } +struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op) { + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + bool is_node = false; + + if (as_up->grad || as_gate->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + + // ggml_out_prod struct ggml_tensor * ggml_out_prod( @@ -8463,6 +8568,27 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort + +struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float thresh) { + bool is_node = false; + + //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) min_entries); + ggml_set_op_params_f32(result, 1, thresh); + + result->op = GGML_OP_ARGSORT_THRESH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} // ggml_top_k @@ -8482,6 +8608,32 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k_thresh + +struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh) { + GGML_ASSERT(a->ne[0] >= k); + + //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh); + struct ggml_tensor * result; + if (min_entries <= 0 || thresh <= 0) { + result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + } else { + result = ggml_argsort_thresh(ctx, a, min_entries, thresh); + } + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -8515,8 +8667,12 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } + // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] } + // v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] } + // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] } // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale, max_bias, softcap }; @@ -9474,7 +9630,7 @@ static void ggml_compute_forward_dup_f16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -9760,7 +9916,7 @@ static void ggml_compute_forward_dup_bf16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -10032,9 +10188,11 @@ static void ggml_compute_forward_dup_f32( } // parallelize by rows + int n_packed = ggml_packed_rows(dst->type); + GGML_ASSERT(dst->ne[1] % n_packed == 0); const int nr = ne01; // number of rows per thread - const int dr = (nr + nth - 1) / nth; + const int dr = n_packed*((nr/n_packed + nth - 1) / nth); // row range for this thread const int ir0 = dr * ith; const int ir1 = MIN(ir0 + dr, nr); @@ -10080,16 +10238,16 @@ static void ggml_compute_forward_dup_f32( ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { + for (int i01 = ir0; i01 < ir1; i01 += n_packed) { const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); - id += rs; + quantize_row_q(src0_ptr, dst_ptr + id, ne00*n_packed); + id += rs*n_packed; } id += rs * (ne01 - ir1); } @@ -10354,17 +10512,22 @@ static void ggml_compute_forward_dup_bytes( // parallelize by rows const int nr = ne01; + const int n_packed = ggml_packed_rows(dst->type); + GGML_ASSERT(nr%n_packed == 0); + const int nrp = nr/n_packed; // number of rows per thread - const int dr = (nr + nth - 1) / nth; + const int drp = (nrp + nth - 1) / nth; + const int dr = drp*n_packed; // row range for this thread const int ir0 = dr * ith; + if (ir0 >= nr) return; const int ir1 = MIN(ir0 + dr, nr); if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size && nb0 == type_size) { // copy by rows - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -10381,7 +10544,7 @@ static void ggml_compute_forward_dup_bytes( if (ggml_is_contiguous(dst)) { size_t id = 0; char * dst_ptr = (char *) dst->data; - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; if (nb00 == type_size) { // src0 is contigous on first dimension, copy by rows @@ -10478,12 +10641,122 @@ static void ggml_compute_forward_dup_bytes( } } +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(ggml_is_quantized(dst->src[0]->type)); + + int64_t nrows = ggml_nrows(dst); + int ith = params->ith; + int nth = params->nth; + + if (dst->src[0]->type == dst->type && + dst->src[0]->nb[0] == ggml_type_size(dst->type) && + dst->nb[0] == ggml_type_size(dst->type)) { + ggml_compute_forward_dup_bytes(params, dst); + return; + } + + if (dst->type == GGML_TYPE_Q8_0 && dst->src[0]->type == GGML_TYPE_Q8_0 && + ggml_are_same_shape(dst, dst->src[0])) { + + if (dst->src[0]->nb[0] == sizeof(block_q8_0) && dst->nb[0] == sizeof(block_q8_0)) { + ggml_compute_forward_dup_bytes(params, dst); + return; + } + + // we assume src is transposed and that's why we are here + + GGML_ASSERT(dst->ne[0] % QK8_0 == 0); + + struct ggml_tensor * const src = dst->src[0]; + GGML_ASSERT(src->nb[1] == sizeof(block_q8_0)); + + float aux[QK8_0]; + + int64_t n_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*n_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + n_per_thread, nrows); + + int64_t nblock = dst->ne[0] / QK8_0; + for (int64_t ir = first_row; ir < last_row; ++ir) { + int64_t i3 = ir/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (ir - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = ir - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + int ib0 = i1/QK8_0; + int iq0 = i1%QK8_0; + for (int ib = 0; ib < nblock; ++ib) { + block_q8_0 * dst_q8 = (block_q8_0 *)((char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]); + float amax = 0; + for (int j = 0; j < QK8_0; ++j) { + int64_t i0 = ib*QK8_0 + j; + const block_q8_0 * src_q8 = (const block_q8_0 *)((const char *)src->data + i0*src->nb[0] + i2*src->nb[2] + i3*src->nb[3]); + float xi = GGML_FP16_TO_FP32(src_q8[ib0].d) * src_q8[ib0].qs[iq0]; + aux[j] = xi; + xi = fabsf(xi); + amax = MAX(amax, xi); + } + float d = amax/127; + dst_q8[ib].d = GGML_FP32_TO_FP16(d); + if (d > 0) { + float id = 1/d; + for (int j = 0; j < QK8_0; ++j) dst_q8[ib].qs[j] = roundf(id*aux[j]); + } else { + memset(dst_q8[ib].qs, 0, QK8_0); + } + } + } + return; + } + + if (dst->type != GGML_TYPE_F32) { + printf("%s: %s -> %s is of type %s\n", __func__, dst->src[0]->name, dst->name, ggml_type_name(dst->type)); + GGML_ABORT("fatal error"); + } + GGML_ASSERT(dst->type == GGML_TYPE_F32); + struct ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type)); + + ggml_to_float_t to_float = type_traits[src0->type].to_float; + GGML_ASSERT(to_float != NULL); + + int n_packed = ggml_packed_rows(src0->type); + GGML_ASSERT(src0->ne[1] % n_packed == 0); + + int64_t n_per_thread = n_packed*((nrows/n_packed + nth - 1)/nth); + int64_t first_row = ith*n_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + n_per_thread, nrows); + + for (int64_t ir = first_row; ir < last_row; ir += n_packed) { + int64_t i03 = ir/(src0->ne[1]*src0->ne[2]); + int64_t i02 = (ir - i03*src0->ne[1]*src0->ne[2])/src0->ne[1]; + int64_t i01 = ir - i03*src0->ne[1]*src0->ne[2] - i02*src0->ne[1]; + int64_t i3 = ir/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (ir - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = ir - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + + const char * q = (const char *)src0->data + i03*src0->nb[3] + i02*src0->nb[2] + i01*src0->nb[1]; + char * f = ( char *)dst->data + i3* dst->nb[3] + i2* dst->nb[2] + i1* dst->nb[1]; + + to_float((const void *)q, (float *)f, src0->ne[0]*n_packed); + } + +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; + if (ggml_is_quantized(src0->type)) { + ggml_compute_forward_dup_q(params, dst); + return; + } + if (src0->type == dst->type) { ggml_compute_forward_dup_bytes(params, dst); return; @@ -10974,6 +11247,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11436,6 +11710,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -11447,6 +11722,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11606,6 +11882,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -11617,6 +11894,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -12467,6 +12745,43 @@ static void ggml_compute_forward_repeat_f16( } } +static void ggml_compute_forward_repeat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(ggml_can_repeat(src, dst)); + GGML_ASSERT(src->type == dst->type); + GGML_ASSERT(src->nb[0] == ggml_type_size(src->type)); + int64_t src_row_size = ggml_row_size(src->type, src->ne[0]); + GGML_ASSERT((int64_t )dst->nb[1] == src_row_size*dst->ne[0]/src->ne[0]); + + int ith = params->ith; + int nth = params->nth; + + int64_t nrows = ggml_nrows(dst); + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*nrows_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + nrows_per_thread, nrows); + + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]; + int64_t i03 = i3 % src->ne[3]; + int64_t i02 = i2 % src->ne[2]; + int64_t i01 = i1 % src->ne[1]; + const char * x = (const char *)src->data + i01*src->nb[1] + i02*src->nb[2] + i03*src->nb[3]; + for (int64_t ir = 0; ir < dst->ne[0]/src->ne[0]; ++ir) { + memcpy(y, x, src_row_size); + y += src_row_size; + } + } +} + static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12487,7 +12802,8 @@ static void ggml_compute_forward_repeat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_repeat_any(params, dst); + //GGML_ABORT("fatal error"); } } } @@ -12595,6 +12911,26 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(dim >= 0 && dim < 4); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) && + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + // simply copy the data + const int64_t size_src_0 = ggml_nbytes(src0); + const int64_t size_src_1 = ggml_nbytes(src1); + const int64_t block_size = 4096; + const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size; + for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) { + const int64_t start = i_block*block_size; + if (start < size_src_0) { + int64_t copy_size = MIN(block_size, size_src_0 - start); + memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size); + } else { + int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start); + memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size); + } + } + return; + } + int64_t o[4] = {0, 0, 0, 0}; o[dim] = src0->ne[dim]; @@ -12620,6 +12956,44 @@ static void ggml_compute_forward_concat_f32( } } +static void ggml_compute_forward_concat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + // Let's do it for dim = 0 only for now + GGML_ASSERT(dim == 0); + + int ith = params->ith; + int nth = params->nth; + + int64_t nrows = ggml_nrows(dst); + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*nrows_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + nrows_per_thread, nrows); + + int64_t src0_row_size = ggml_row_size(src0->type, src0->ne[0]); + int64_t src1_row_size = ggml_row_size(src1->type, src1->ne[0]); + + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]; + const char * x0 = (const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3]; + const char * x1 = (const char *)src1->data + i1*src1->nb[1] + i2*src1->nb[2] + i3*src1->nb[3]; + memcpy(y, x0, src0_row_size); + memcpy(y + src0_row_size, x1, src1_row_size); + } + +} + static void ggml_compute_forward_concat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12634,7 +13008,8 @@ static void ggml_compute_forward_concat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_concat_any(params, dst); + //GGML_ABORT("fatal error"); } } } @@ -14108,39 +14483,20 @@ static void ggml_compute_forward_mul_mat( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE +#if GGML_USE_LLAMAFILE // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; #endif #if GGML_USE_IQK_MULMAT - if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) { - int counter = 0; - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - if (counter++ % nth == ith) { - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - 0, 1)) goto IQK_MulMat_Not_Available1; - } - } - } - return; - } if (dst->type == GGML_TYPE_F32) { - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available1; - return; + if (iqk_mul_mat_4d(ne01, ne11, ne00, + ne02, ne03, ne12, ne13, nb02, nb03, nb12, nb13, nb2/sizeof(float), nb3/sizeof(float), + src0->type, src0->data, nb01, + src1->type, src1->data, nb11, + (float *)dst->data, nb1/sizeof(float), ith, nth)) return; } -IQK_MulMat_Not_Available1:; #endif #if GGML_USE_LLAMAFILE @@ -14179,7 +14535,40 @@ UseGgmlGemm1:; const size_t nbw3 = nbw2*ne12; assert(params->wsize >= ne13*nbw3); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + if (src1->type != GGML_TYPE_F32) { +#if GGML_USE_IQK_MULMAT + char * work_buffer = wdata + ne13*nbw3 + ith*ne10*sizeof(float); + GGML_ASSERT(params->wsize >= ne13*nbw3 + nth*ne10*sizeof(float)); + iqk_quantize_any(src1->type, vec_dot_type, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + src1->data, wdata, work_buffer, type_traits[src1->type].to_float, from_float, ith, nth); +#else + GGML_ABORT("fatal error"); +#endif + } + else { + +//#ifdef GGML_USE_IQK_MULMAT +// int ts = type_traits[vec_dot_type].type_size; +// int bs = type_traits[vec_dot_type].blck_size; +// int64_t blocks_per_row = ne10/bs; +// int64_t num_blocks = ne11*ne12*ne13*blocks_per_row; +// int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved +// // with trying to get it from ggml +// int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd; +// int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd; +// int64_t first_block = ith*block_per_thread; +// int64_t last_block = MIN(num_blocks, first_block + block_per_thread); +// while (first_block < last_block) { +// int64_t i13 = first_block/(ne11*ne12*blocks_per_row); +// int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row); +// int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row; +// int64_t i10 = first_block % blocks_per_row; +// int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block); +// from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs, +// (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs); +// first_block += blocks_to_do; +// } +//#else for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -14201,6 +14590,8 @@ UseGgmlGemm1:; } } } +//#endif + } ggml_barrier(params->shared); @@ -14221,34 +14612,14 @@ UseGgmlGemm1:; #if GGML_USE_IQK_MULMAT if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { - // When K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing - // one matrix multiplication, but work on several heads at once. - // Hence, we find the GCD(n12*ne13, nth) and have nth/GCD(n12*ne13, nth) threads per head. - // Leaving the previous version commented out for now just in case. const size_t row_size = ggml_row_size(vec_dot_type, ne10); - int ntg = simple_gcd(ne12*ne13, nth); - int counter = 0; - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - if (counter++ % ntg == ith%ntg) { - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2; - } - } - } - //for (int64_t i13 = 0; i13 < ne13; i13++) - // for (int64_t i12 = 0; i12 < ne12; i12++) - // if (!iqk_mul_mat(ne01, ne11, ne00, - // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - // (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - // ith, nth)) goto IQK_MulMat_Not_Available2; - return; + if (iqk_mul_mat_4d(ne01, ne11, ne00, + ne02, ne03, ne12, ne13, nb02, nb03, row_size*ne11, row_size*ne11*ne12, + nb2/sizeof(float), nb3/sizeof(float), + src0->type, src0->data, nb01, + vec_dot_type, wdata, row_size, + (float *)dst->data, nb1/sizeof(float), ith, nth)) return; } -IQK_MulMat_Not_Available2:; #endif #if GGML_USE_LLAMAFILE @@ -14406,7 +14777,7 @@ static void ggml_compute_forward_mul_mat_id( char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t)); struct mmid_row_mapping { int32_t i1; @@ -14448,7 +14819,8 @@ static void ggml_compute_forward_mul_mat_id( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -14614,6 +14986,156 @@ IQK_MulMat_Not_Available:; #undef MMID_MATRIX_ROW } +#if GGML_USE_IQK_MULMAT +static void ggml_compute_forward_mul_mat_id_up_gate( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(dst->src[0]->type == dst->src[1]->type); + GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1])); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const struct ggml_tensor * src1 = dst->src[2]; + const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * src0_1 = dst->src[0]; + const struct ggml_tensor * src0_2 = dst->src[1]; + const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + GGML_ASSERT(ne13 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + char * wdata_src1_end = (src1->type == vec_dot_type) ? + (char *) params->wdata : + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t)); + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + + if (src1->type != vec_dot_type) { + + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + + char * wdata = params->wdata; + + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } + } + } + + ggml_barrier(params->shared); + + + // so GGML_TENSOR_BINARY_OP_LOCALS works + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; + const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + // + if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], + type, src0_1_cur, src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); + +// if (nth%2 == 0) { +// const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; +// void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// type, src0_d, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst_d, nb1, nb2, +// matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); +// +// } else { +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_1->type, (const char *)src0_1_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst1->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_2->type, (const char *)src0_2_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst2->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// } + } + +#undef MMID_MATRIX_ROW +} +#endif + // ggml_compute_forward_out_prod static void ggml_compute_forward_out_prod_f32( @@ -14830,6 +15352,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -14841,6 +15364,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15240,6 +15764,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15251,6 +15776,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15397,7 +15923,11 @@ static void ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 < 0 || i01 >= ne01) { + memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + continue; + } + //assert(i01 >= 0 && i01 < ne01); dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -15438,11 +15968,14 @@ static void ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 >= 0 && i01 < ne01) { + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } - ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -15479,11 +16012,13 @@ static void ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_bf16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + if (i01 >= 0 && i01 < ne01) { + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } } } @@ -15520,11 +16055,13 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + if (i01 >= 0 && i01 < ne01) { + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } else { + memset((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + } } } @@ -15541,9 +16078,11 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15555,6 +16094,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15972,28 +16512,28 @@ static void ggml_compute_forward_soft_max_f32( } } -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif +//#ifndef NDEBUG +// for (int i = 0; i < nc; ++i) { +// //printf("p[%d] = %f\n", i, p[i]); +// assert(!isnan(wp[i])); +// } +//#endif float max = -INFINITY; ggml_vec_max_f32(nc, &max, wp); ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); + //assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(nc, dp, sum); -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif +//#ifndef NDEBUG +// for (int i = 0; i < nc; ++i) { +// assert(!isnan(dp[i])); +// assert(!isinf(dp[i])); +// } +//#endif } } @@ -16176,6 +16716,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -16187,6 +16728,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_KR8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: @@ -16233,6 +16775,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: case GGML_TYPE_Q8_K128: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_K16: case GGML_TYPE_Q8_K32: case GGML_TYPE_Q4_0_4_4: @@ -17486,6 +18029,75 @@ static void ggml_compute_forward_argsort( } } +// ggml_compute_forward_argsort_thresh + +static void ggml_compute_forward_argsort_thresh_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + int min_entries = ggml_get_op_params_i32(dst, 0); + float thresh = ggml_get_op_params_f32(dst, 1); + + //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if (src_data[dst_data[j]] < src_data[dst_data[k]]) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + float max_value = src_data[dst_data[0]]; + //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]); + for (int j = min_entries; j < ne0; ++j) { + if (src_data[dst_data[j]] < max_value*thresh) { + //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]); + dst_data[j] = -1; + } + } + } +} + +static void ggml_compute_forward_argsort_thresh( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_thresh_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -17508,10 +18120,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ith = params->ith; const int nth = params->nth; - const int64_t D = neq0; - const int64_t N = neq1; + const int64_t Dk = nek0; + const int64_t Dv = nev0; + const int64_t N = neq1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == Dv); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -17519,12 +18132,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == Dk); + GGML_ASSERT(nek0 == Dk); + GGML_ASSERT(nev0 == Dv); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(nev0 == Dv); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -17564,44 +18177,57 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT - if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { - // I keep changing my mind what is the best strategy to split the threads when processing - // multiple heads. This is my current thinking, the commented out code below was the previous. - int ntg = nth/simple_gcd(neq2*neq3, nth); - int64_t neq1g = (neq1 + ntg - 1)/ntg; - //int64_t work_per_slice = D*nek1*neq1; - //int ntg = 1; - // - // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix - // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of - // the number of threads processing the (iq2, iq3) matrix. - // - //if (neq1 >= 8*nth) { - // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; - // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; - // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; - //} - int counter = 0; - for (int64_t iq3 = 0; iq3 < neq3; iq3++) { - for (int64_t iq2 = 0; iq2 < neq2; iq2++) { - if (counter++ % (nth/ntg) == ith/ntg) { - int iq1 = (ith%ntg)*neq1g; - int this_neq1 = MIN(neq1g, neq1-iq1); - if (!iqk_flash_attn_noalibi(k->type, v->type, - D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), - (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), - (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), - (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), - (const void *)((const char *)mask->data + iq1*mask->nb[1]), - scale, softcap, - (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; - } - } - } - return; -IQK_Flash_Attn_NotAvailable:; - } + if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, + q->ne[3], q->ne[2], q->nb[3], q->nb[2], + k->ne[3], k->ne[2], k->nb[3], k->nb[2], + v->ne[3], v->ne[2], v->nb[3], v->nb[2], + dst->ne[2], dst->ne[1], dst->nb[1], + k->type, v->type, + Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], + q->data, k->data, v->data, mask->data, + scale, softcap, (float *)dst->data, + params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; +// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { +// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", +// // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); +// // I keep changing my mind what is the best strategy to split the threads when processing +// // multiple heads. This is my current thinking, the commented out code below was the previous. +// int ntg = nth/simple_gcd(neq2*neq3, nth); +// int64_t neq1g = (neq1 + ntg - 1)/ntg; +// //int64_t work_per_slice = D*nek1*neq1; +// //int ntg = 1; +// // +// // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix +// // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of +// // the number of threads processing the (iq2, iq3) matrix. +// // +// //if (neq1 >= 8*nth) { +// // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; +// // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; +// // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; +// //} +// int counter = 0; +// for (int64_t iq3 = 0; iq3 < neq3; iq3++) { +// for (int64_t iq2 = 0; iq2 < neq2; iq2++) { +// if (counter++ % (nth/ntg) == ith/ntg) { +// int iq1 = (ith%ntg)*neq1g; +// int this_neq1 = MIN(neq1g, neq1-iq1); +// if (!iqk_flash_attn_noalibi(k->type, v->type, +// Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), +// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), +// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), +// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), +// (const void *)((const char *)mask->data + iq1*mask->nb[1]), +// scale, softcap, +// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; +// } +// } +// } +// return; +//IQK_Flash_Attn_NotAvailable:; +// printf("iqk_flash was rejected\n"); +// } #endif const uint32_t n_head = neq2; @@ -17615,6 +18241,8 @@ IQK_Flash_Attn_NotAvailable:; ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + const int64_t Dkv = MAX(Dk, Dv); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -17628,15 +18256,15 @@ IQK_Flash_Attn_NotAvailable:; float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + float * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16 if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t)); } else { - memset(VKQ32, 0, D*sizeof(float)); + memset(VKQ32, 0, Dkv*sizeof(float)); } const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -17650,7 +18278,7 @@ IQK_Flash_Attn_NotAvailable:; const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + q_to_vec_dot(pq, Q_q, Dk); // online softmax / attention // loop over n_kv and n_head_kv @@ -17664,7 +18292,7 @@ IQK_Flash_Attn_NotAvailable:; float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1); s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask @@ -17682,14 +18310,14 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); + ggml_vec_scale_f16(Dv, VKQ16, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f @@ -17697,30 +18325,30 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); + ggml_vec_scale_f32(Dv, VKQ32, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } - v_to_float(v_data, V32, D); + v_to_float(v_data, V32, Dv); // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(Dv, VKQ32, V32, vs); } S = S*ms + vs; // scale and increment sum with partial sum } if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < Dv; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } // V /= S const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + ggml_vec_scale_f32(Dv, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -19042,17 +19670,19 @@ static void ggml_compute_forward_cross_entropy_loss_back( ///////////////////////////////// -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { +static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); + GGML_UNUSED(next); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { - return; + return false; } #if IK_PRINT_TIMING int64_t t1 = ggml_time_us(); #endif + bool skip_next = false; switch (tensor->op) { case GGML_OP_DUP: { @@ -19162,6 +19792,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul_mat_id(params, tensor); } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + ggml_compute_forward_mul_mat_id_up_gate(params, tensor); + } break; case GGML_OP_OUT_PROD: { ggml_compute_forward_out_prod(params, tensor); @@ -19286,6 +19920,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_ARGSORT_THRESH: + { + ggml_compute_forward_argsort_thresh(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -19402,6 +20040,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm int64_t t2 = ggml_time_us(); if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1)); #endif + return skip_next; } //////////////////////////////////////////////////////////////////////////////// @@ -19918,6 +20557,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_MOE_FUSED_UP_GATE: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_OUT_PROD: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -20266,6 +20909,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_ARGSORT_THRESH: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -20928,6 +21575,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -20985,6 +21633,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: @@ -21115,6 +21764,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + if (node->src[1]->type != GGML_TYPE_F32) { + cur += n_tasks*node->src[1]->ne[0]*sizeof(float); // src1->type -> f32 -> vec_dot_type + } } } break; case GGML_OP_MUL_MAT_ID: @@ -21131,6 +21783,20 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + cur = 0; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src2 = node->src[2]; + const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src2->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + } + const int n_as = src0->ne[2]; + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src2->ne[2] * sizeof(int64_t); // matrix_rows + } break; case GGML_OP_OUT_PROD: { if (ggml_is_quantized(node->src[0]->type)) { @@ -21184,9 +21850,62 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t Dk = node->src[0]->ne[0]; + const int64_t Dv = node->src[2]->ne[0]; + const int64_t D = MAX(Dk, Dv); - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread +#if GGML_USE_IQK_MULMAT + size_t qsize = 0; + const struct ggml_tensor * q = node->src[0]; + const struct ggml_tensor * k = node->src[1]; + if (k->type == GGML_TYPE_Q8_0) { + qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]); + } + if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { + if (k->ne[2] > 1) { + int gcd = simple_gcd(k->ne[2], n_tasks); + int nth_k = n_tasks/gcd; + int nek2_k = k->ne[2]/gcd; + int nchunk = nek2_k*k->ne[1]/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); + int nstep_k = k->ne[2]*k->ne[1]/nk; + size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); + size_t size = nstep_k*result_size; + cur = MAX(cur, size+qsize); + } else { + int nstep_k = k->ne[1]/32; + int gcd_k = simple_gcd(nstep_k, n_tasks); + if (gcd_k > 1) { + int nth_k = n_tasks/gcd_k; + int rk2 = q->ne[2]/k->ne[2]; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; + } + cur = MAX(cur, size+qsize); + } + } + } else { + cur = MAX(cur, qsize); + } +#endif } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -21247,12 +21966,26 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.shared=*/ state->shared, }; +#if IK_PRINT_TIMING + int64_t t_start = ggml_time_us(); + int64_t t_eval = 0; +#endif + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; if (ggml_is_noop(node)) continue; - ggml_compute_forward(¶ms, node); +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif + if (ggml_compute_forward(¶ms, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) { + ++node_n; + } +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + t_eval += tim2 - tim1; +#endif if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { state->shared->ec = GGML_STATUS_ABORTED; @@ -21264,6 +21997,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { break; } } +#if IK_PRINT_TIMING + int64_t t_end = ggml_time_us(); + if (state->ith == 0) printf("ggml_barrier(...): %d us\n", (int)(t_end - t_start - t_eval)); +#endif return 0; } @@ -21289,6 +22026,9 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl #ifdef GGML_USE_OPENMP if (n_threads > 1) { +//#if IK_PRINT_TIMING +// int64_t tim1 = ggml_time_us(); +//#endif #pragma omp parallel num_threads(n_threads) { #pragma omp single @@ -21305,6 +22045,10 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl }; ggml_graph_compute_thread(&worker); } +//#if IK_PRINT_TIMING +// int64_t tim2 = ggml_time_us(); +// printf("%s(...): %d us\n", __func__, (int)(tim2-tim1)); +//#endif } else { struct ggml_compute_state worker = { .thrd = 0, @@ -23039,6 +23783,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -23050,6 +23795,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h new file mode 100644 index 00000000..dc3e369f --- /dev/null +++ b/ggml/src/iqk/iqk_common.h @@ -0,0 +1,138 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp fenc=utf-8 :vi +// +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "iqk_config.h" + +#if defined IQK_IMPLEMENT + +#include +#include +#include + +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "iqk_mul_mat.h" +#include "iqk_quantize.h" + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#define FA_TIMING 0 + +#include +#include +#if FA_TIMING +#include +#include +struct Perf { + using TimePoint = std::chrono::time_point; + std::array times = {}; + std::mutex mutex; + bool report; + static auto cur_time() { return std::chrono::high_resolution_clock::now(); } + inline void accum(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + std::lock_guard lock(mutex); + times[what] += dt; + } + inline void accum_nolock(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + times[what] += dt; + } + inline void add(const Perf& other) { + std::lock_guard lock(mutex); + for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i]; + } + Perf(bool r) : report(r) {} + ~Perf() { + if (report) { + double tot = 0; + for (auto& t : times) tot += t; + if (!tot) return; + printf("======================= Timing: %g ms in total\n", tot); + for (int i = 0; i < int(times.size()); ++i) { + if (times[i]) { + printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%'); + } + } + } + } + static Perf& instance() { + static Perf p(true); + return p; + } + static double delta(const TimePoint& t1, const TimePoint& t2) { + return 1e-6*std::chrono::duration_cast(t2-t1).count(); + } +}; +#endif + +#ifdef __AVX2__ +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +#endif + +namespace { + +typedef struct { + int32_t i1; + int32_t i2; +} mmid_row_mapping; + +struct DataInfo { + float * s; + const char * cy; + size_t bs; + size_t by; + int cur_y = 0; + int ne11; + const mmid_row_mapping * row_mapping = nullptr; + size_t bs2 = 0; + + inline const char * src1_row(int iy) const { + if (!row_mapping) return cy + (cur_y + iy)*by; + int i11 = row_mapping[cur_y + iy].i1 % ne11; + int i12 = row_mapping[cur_y + iy].i2; + return cy + (i11 + i12*ne11)*by; + } + + inline void store(int ix, int iy, float result) const { + *(dst_row(iy) + ix) = result; + } +#ifdef __AVX__ + inline void store(int ix, int iy, __m128 result) const { + _mm_storeu_ps(dst_row(iy) + ix, result); + } + inline void store(int ix, int iy, __m256 result) const { + _mm256_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __AVX512F__ + inline void store(int ix, int iy, __m512 result) const { + _mm512_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __ARM_NEON + inline void store(int ix, int iy, float32x4_t result) const { + vst1q_f32(dst_row(iy) + ix, result); + } +#endif + inline float * dst_row(int iy) const { + if (!row_mapping) return s + (cur_y + iy)*bs; + int i12 = row_mapping[cur_y + iy].i2; + int i1 = row_mapping[cur_y + iy].i1; + int i2 = i12; + return s + i1*bs + i2*bs2; + } +}; + +typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); + +#endif diff --git a/ggml/src/iqk/iqk_config.h b/ggml/src/iqk/iqk_config.h new file mode 100644 index 00000000..3d8238d7 --- /dev/null +++ b/ggml/src/iqk/iqk_config.h @@ -0,0 +1,50 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#pragma once + +#if defined IQK_IMPLEMENT +#undef IQK_IMPLEMENT +#endif + +#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD +#define IQK_IMPLEMENT +#endif + +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define IQK_API __declspec(dllexport) +# else +# define IQK_API __declspec(dllimport) +# endif +# else +# define IQK_API __attribute__ ((visibility ("default"))) +# endif +#else +# define IQK_API +#endif + +#ifdef _MSC_VER +#define IQK_NOINLINE __declspec(noinline) +#define IQK_ALWAYS_INLINE inline +#if !defined __x86_64__ && defined _M_X64 +#define __x86_64__ +#endif +#else +#define IQK_NOINLINE __attribute__((__noinline__)) +#define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) +#endif + +#if defined __x86_64__ +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif +#endif + diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp new file mode 100644 index 00000000..610f18b7 --- /dev/null +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -0,0 +1,357 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "iqk_config.h" +#include "iqk_mul_mat.h" +#include "iqk_flash_impl.h" + +#ifdef IQK_IMPLEMENT + +#include +#include +#include +#include +#include +#include + +namespace { +inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + while (a != b) { + if (a > b) a -= b; + else b -= a; + } + return a; +} +inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) { + if (Mj == -INFINITY) return; + if (Mj > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj; + } else { + float c = exp(M - Mj); + S = c*S + Sj; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj; + } else { + float c = exp(Mj - M); + S += c*Sj; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } +} +} + +// TODO: get the ggml_type enum here without polution +// +extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, + int neq3, int neq2, long nbq3, long nbq2, + int nek3, int nek2, long nbk3, long nbk2, + int nev3, int nev2, long nbv3, long nbv2, + int ne2, int ne1, long nb1, + int int_type_k_in, // type of k + int int_type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int neq1, // number of columns in q + int nek1, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + const void * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + int ith, int nth) { + + if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; + + int rk2 = neq2/nek2; + int rv2 = neq2/nev2; + int rk3 = neq3/nek3; + int rv3 = neq3/nev3; + + int int_type_k = int_type_k_in; + auto work_buffer = work_buffer_in; + if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) { + uint64_t row_size = 0; + work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + if (int_type_k != int_type_k_in) { + stride_k = row_size; + nbk2 = stride_k*nek1; + nbk3 = nbk2*nek2; + k = work_buffer_in; + barrier(barrier_data); + } + } + //uint64_t row_size = 0; + //auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + //if (int_type_k != int_type_k_in) { + // stride_k = row_size; + // nbk2 = stride_k*nek1; + // nbk3 = nbk2*nek2; + // k = work_buffer_in; + // barrier(barrier_data); + //} + + // Getting confused all the time about where to load data from and store the results to + // (especially when combining the results from the threads). + // So, for now, making it work just for MLA (nek2 = 1). + // I think it would also speed up things for GQA, but I'm leaving this for another day. + if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) { + int nstep_k = nek1/32; + int gcd_k = simple_gcd(nstep_k, nth); + if (gcd_k >= 1) { + int nth_k = nth/gcd_k; + int ith_k = ith%gcd_k; + int ith_q = ith/gcd_k; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + if (nq_per_thread > 1) { + int ith_mid = nth_k; + int nq_this_thread = nq_per_thread; + if (nq_per_thread*nth_k > rk2) { + ith_mid = rk2 - nth_k*(nq_per_thread - 1); + if (ith_q >= ith_mid) --nq_this_thread; + } + int j_mid = ith_mid*nq_per_thread; + auto work = (char *)work_buffer; + auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float); + auto result_buffer = work; + + auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k; + auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; + auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2; + auto qth = (const char *)q + q_offset; + auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here + + // Each thread will produce a result of size Dv*nq_this_thread*sizeof(float) + // In addition, we need M, S for the nq_this_thread rows the thread is processing + // => (Dv + 2)*nq_per_thread*sizeof(float). We use (Dv + 16) instead to make sure threads are not + // writing onto the same cache line. + auto work_this_thread = (float *)(result_buffer + ith*size_thread); + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, + (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, + scale, softcap, + work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; + + barrier(barrier_data); + + // There are nek1/gcd_k contributions for each j that we need to sum up + // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread + + // TODO: simdify this + // TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now + for (int j = ith; j < rk2; j += nth) { + auto Racc = qkv + j*nb1/sizeof(float); + float M = -INFINITY, S = 0; + int jth_first, jj, nq_this_j; + if (j < j_mid) { + jth_first = j/nq_per_thread; + jj = j%nq_per_thread; + nq_this_j = nq_per_thread; + } else { + jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1); + jj = (j - j_mid)%(nq_per_thread-1); + nq_this_j = nq_per_thread - 1; + } + jth_first *= gcd_k; + for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) { + auto R = (const float *)(result_buffer + jth*size_thread); + auto Mj = R + Dv*nq_this_j; + auto Sj = Mj + nq_this_j; + R += jj*Dv; + accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + } + } + + if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { + auto result_size = (Dv + 16)*rk2*sizeof(float); + int gcd = simple_gcd(nek2, nth); + if (false && gcd > 1) { + int nth_g = nth/gcd; + int ith_g = ith%nth_g; + int nek1_32 = nek1/32; + int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; + int ith_mid = nth_g; + if (nek1_pt*nth_g > nek1_32) { + ith_mid = nek1_32 - nth_g*(nek1_pt - 1); + } + nek1_pt *= 32; + int nek1_mid = ith_mid*nek1_pt; + int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; + for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { + int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; + auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + float M = -INFINITY, S = 0; + for (int ig = 0; ig < nth_g; ++ig) { + int istep_k = ik02*nth_g + ig; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + int nth_k = nth/gcd; + int nek2_k = nek2/gcd; + int nchunk = nek2_k*nek1/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (nek2*nek1/(32*nth)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (nek2*nek1/(32*nth)); + int nkk = (nek1 + nk - 1)/nk; + int nstep_k = nek2*nkk; + //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); + for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { + int ik02 = istep_k/nkk; + int ik01 = nk*(istep_k - ik02*nkk); + int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01; + if (this_nk <= 0) break; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + // We have nkk results for each head + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + // ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2; + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + //std::memset(Racc, 0, Dv*sizeof(float)); + float M = -INFINITY, S = 0; + for (int ikk = 0; ikk < nkk; ++ikk) { + int istep_k = ik02*nkk + ikk; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + + // I keep changing my mind what is the best strategy to split the threads when processing + // multiple heads. This is my current thinking, the commented out code below was the previous. + int ntg = nth/simple_gcd(neq2*neq3, nth); + int neq1g = (neq1 + ntg - 1)/ntg; + //int64_t work_per_slice = D*nek1*neq1; + //int ntg = 1; + // + // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix + // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of + // the number of threads processing the (iq2, iq3) matrix. + // + //if (neq1 >= 8*nth) { + // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + //} + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % (nth/ntg) == ith/ntg) { + int iq1 = (ith%ntg)*neq1g; + int this_neq1 = std::min(neq1g, neq1-iq1); + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float), + (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), + (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), + (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), + (const void *)((const char *)mask + iq1*stride_m), + scale, softcap, + (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; + } + } + } + + return true; +} + +#else + +bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int type_mask, [[maybe_unused]] float max_bias, + [[maybe_unused]] int neq3, [[maybe_unused]] int neq2, [[maybe_unused]] long nbq3, [[maybe_unused]] long nbq2, + [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2, + [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2, + [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1, + [[maybe_unused]] int type_k, // type of k + [[maybe_unused]] int type_v, // type of v + [[maybe_unused]] int Dk, // K head size + [[maybe_unused]] int Dv, // V head size + [[maybe_unused]] int nq, // number of columns in q + [[maybe_unused]] int nk, // number of rows in k + [[maybe_unused]] int stride_q, // distance between q columns in bytes + [[maybe_unused]] int stride_k, // distance between k rows in bytes + [[maybe_unused]] int stride_v, // distance between v rows in bytes + [[maybe_unused]] int stride_m, // distance between mask rows (in bytes + [[maybe_unused]] const void * q, // q matrix. + [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + [[maybe_unused]] float scale, // scale applied before softmax + [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax + [[maybe_unused]] float * qkv, // v*softmax(scale*(k*q)) + [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + [[maybe_unused]] int ith, [[maybe_unused]] int nth) { + return false; +} + +#endif + diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h new file mode 100644 index 00000000..6f62e56b --- /dev/null +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -0,0 +1,33 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +bool iqk_flash_attn_impl(int type_k, // type of k + int type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int nq, // number of columns in q + int nk, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + float * M, + float * S); + +void * iqk_repack_k(int type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * k, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b2bcfa1d..39904649 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -8,24 +8,19 @@ // #include -#if defined IQK_IMPLEMENT -#undef IQK_IMPLEMENT -#endif +#include "iqk_config.h" -#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD -#define IQK_IMPLEMENT -#endif +#if defined IQK_IMPLEMENT #include #include #include -#if defined IQK_IMPLEMENT - #include "ggml-impl.h" #include "ggml-quants.h" #include "iqk_mul_mat.h" #include "iqk_quantize.h" +#include "iqk_flash_impl.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -101,26 +96,6 @@ struct Perf { }; #endif -#ifdef _MSC_VER -#define IQK_NOINLINE __declspec(noinline) -#define IQK_ALWAYS_INLINE inline -#if !defined __x86_64__ && defined _M_X64 -#define __x86_64__ -#endif -#else -#define IQK_NOINLINE __attribute__((__noinline__)) -#define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) -#endif - -#if defined __x86_64__ -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD -#endif -#endif - namespace { typedef struct { @@ -244,6 +219,118 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } + inline void gelu(int n, const float * src, float * dst); + inline void relu(int n, const float * src, float * dst); + inline void silu(int n, const float * src, float * dst); + inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); + else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); + else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else GGML_ABORT("fatal error"); + } + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { +#ifdef __aarch64__ + constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) +#else + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) +#endif + auto op = ggml_unary_op(unary_op); + float tmp[k_x_step*16]; + if (func16 && nrc_y >= 16) { + int n_step = (nrc_y - info.cur_y)/16; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + } + func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += 16; + } + } + info.cur_y += 16 * n_step; + if (info.cur_y == nrc_y) return; + } + int ny = funcs.size(); + while (!funcs[ny-1] && ny > 0) --ny; + int n_left = nrc_y - info.cur_y; + int n_step = n_left/ny; + if (n_step > 0) { + if (n_step*ny != n_left) { + ++n_step; + int ny1 = n_left/n_step; + int ny2 = ny1 + 1; + int my1 = n_step*ny2 - n_left; + int my2 = n_step - my1; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < my1; ++iy) { + funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny1; + } + for (int iy = 0; iy < my2; ++iy) { + funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny2; + } + } + info.cur_y += n_left; + } + else { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny; + } + } + info.cur_y += ny * n_step; + } + } + n_left = nrc_y - info.cur_y; + if (n_left > 0) { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + } + } + } static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); static inline int num_rows(ggml_type type) { #ifdef HAVE_FANCY_SIMD @@ -270,6 +357,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: @@ -302,6 +391,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -314,7 +405,7 @@ private: } -bool iqk_mul_mat(long Nx, long Ny, long ne00, +extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth) { @@ -341,7 +432,122 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, return true; } -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, +namespace { +inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + while (a != b) { + if (a > b) a -= b; + else b -= a; + } + return a; +} +} + +extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, + long ne02, long ne03, long ne12, long ne13, + long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth) { + + auto r2 = ne12 / ne02; + auto r3 = ne13 / ne03; + + if (ne13 == 1 && Ny == 1 && r2 > 1) { + if (Nx >= 256 && Nx%32 == 0) { + int nx32 = Nx/32; + int nchunk = nx32*ne02; + if (r2 <= 8) { + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false; + int nx64 = Nx/64; + int nchunk64 = nx64*ne02; + for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) { + int i02 = ichunk/nx64; + int ix = 64*(ichunk - i02*nx64); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64); + } + int ix0 = 64*nx64; + if (ix0 < Nx) { + nx32 -= 2*nx64; + nchunk = nx32*ne02; + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ix0 + 32*(ichunk - i02*nx32); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + } + } + //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + // int i02 = ichunk/nx32; + // int ix = 32*(ichunk - i02*nx32); + // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + //} + return true; + } + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ichunk - i02*nx32; + if (!iqk_mul_mat(32, r2, ne00, + typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA, + typeB, (const char *)B + i02*r2*nb12, nb12, + C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false; + + } + return true; + } + //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02); + int gcd = simple_gcd(ne02, nth); + int counter = 0; + for (int64_t i12 = 0; i12 < ne02; i12++) { + if ((counter++ % gcd) == (ith%gcd)) { + if (!iqk_mul_mat(Nx, r2, ne00, + typeA, (const char *)A + i12*nb02, strideA, + typeB, (const char *)B + i12*r2*nb12, nb12, + C + r2*i12*nb2, nb2, + ith/gcd, nth/gcd)) return false; + } + } + return true; + } + + if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) { + //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx); + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + int n_per_thread = (Nx + nth - 1)/nth; + int first = ith*n_per_thread; + if (first >= Nx) return true; + int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx; + for (int ix = first; ix < last; ++ix) { + for (int i02 = 0; i02 < ne02; ++i02) { + DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1); + } + } + return true; + } + + int gcd = simple_gcd(ne12*ne13, nth); + int counter = 0; + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + if ((counter++ % gcd) == (ith%gcd)) { + if (!iqk_mul_mat(Nx, Ny, ne00, + typeA, (const char *)A + i12/r2*nb02 + i13/r3*nb03, strideA, + typeB, (const char *)B + i12*nb12 + i13*nb13, strideB, + C + i12*nb2 + i13*nb3, stride_C, + ith/gcd, nth/gcd)) return false; + } + } + } + return true; +} + +extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { @@ -367,6 +573,34 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, return true; } +extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { + + const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; + assert(row_mapping != nullptr); + + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + size_t row_size_qx = strideA; + size_t row_size_qy = strideB; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + return true; +} + + namespace { inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { @@ -1156,7 +1390,7 @@ static const uint32_t iq1s_grid_us[2048] = { }; #endif -#ifndef HAVE_FANCY_SIMD +#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__) const uint64_t keven_signs[128] = { 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, @@ -1419,7 +1653,7 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); @@ -1436,7 +1670,7 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, } #endif } else { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); @@ -1518,6 +1752,15 @@ __m256i inline load_iq4nl_values_256() { return MM256_SET_M128I(val128, val128); } +__m128i inline load_iq4k_values_128() { + return _mm_loadu_si128((const __m128i *)iq4k_values); +} + +__m256i inline load_iq4k_values_256() { + auto val128 = load_iq4k_values_128(); + return MM256_SET_M128I(val128, val128); +} + #ifdef HAVE_FANCY_SIMD //====================================== Zen4 ================================================== @@ -2694,7 +2937,7 @@ struct DequantizerIQ6K final : public BaseDequantizer { auto h1 = _mm256_andnot_si256(mask4, hbits); auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1); auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2); - auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff)); + auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff; return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)), _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))), _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)), @@ -2790,7 +3033,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer { const __m256i values; __m256i data[4]; const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001); - const __m256i bmask = _mm256_set1_epi16(0xfffe); + const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe; const __m128i mask = _mm_set1_epi16(254); const __m128i m127 = _mm_set1_epi16(-127); const __m128i m128 = _mm_set1_epi16(-128); @@ -3544,9 +3787,9 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da #ifdef HAVE_FANCY_SIMD template -static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto values = load_iq4nl_values_512(); int nb = n / QK4_NL; @@ -3583,7 +3826,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -3600,9 +3844,10 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -3617,9 +3862,9 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data } #else template -static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m1 = _mm256_set1_epi16(1); auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); @@ -3656,7 +3901,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4[4*ib4+k]); @@ -3672,7 +3918,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } @@ -3724,14 +3971,13 @@ inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55))); auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)), _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff))); - auto sumi = _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi1, sumi2)), - _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi3, sumi4))); + auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4))); #endif return sumi; } template -static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); @@ -3745,7 +3991,7 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D auto acc1 = _mm256_setzero_ps(); auto acc2 = _mm256_setzero_ps(); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - helper.vec = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)); + helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16)); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); @@ -3760,9 +4006,10 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); auto sumi = accum_q4_0_quants(v, qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); - acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2); } acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); info.store(ix, 0, acc1); @@ -3780,7 +4027,7 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); } for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, scales); auto m4 = _mm256_extractf128_ps(scales, 1); auto m8 = _mm256_set_m128(m4, m4); @@ -3808,9 +4055,10 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = accum_q4_0_quants(v, qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -3908,6 +4156,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +// sum[ qy_i * ls_k * (qx_i - 1+/-delta_k)] +// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8+/-o_k)] +// = 1/8 * ( sum[qy_i * qx_i * 8*ls+k] - sum[qy_i * ls_k * (8+/-o_k)] ) + +template +static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8 q8(info); + __m256i qx[8]; + __m256i scales[4]; + __m256 acc[nrc_y] = {}; + auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000 + __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100); + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + float d = GGML_FP16_TO_FP32(iq1s[ibl].d); + auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh); + auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7)); + scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1)); +#ifdef HAVE_FANCY_SIMD + auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask); + auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9)); +#else + auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask); + auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7))); +#endif + deltas128 = _mm_mullo_epi16(scales128, deltas128); + scales128 = _mm_slli_epi16(scales128, 3); + auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128); + auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128); + auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7 + auto all_scales = MM256_SET_M128I(scales128, scales128); + auto shuffle = shuffle0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle); + shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)); + } + const uint8_t * qs = iq1s[ibl].qs; + const uint16_t * qh = iq1s[ibl].qh; + for (int ib = 0; ib < QK_K/32; ib += 2) { + qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0); + auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1); +#ifdef HAVE_FANCY_SIMD + auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1); + auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2); + sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2)); +#else + auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1); + auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2); + auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2)); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot)); +#endif + } +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas)); +#endif + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*hsum_float_8(acc[iy])); + acc[iy] = _mm256_setzero_ps(); + } + } +} + template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -4006,9 +4333,9 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q4_0_r8_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x); return; } GGML_ASSERT(nrc_x%16 == 0); @@ -4053,7 +4380,8 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -4070,9 +4398,10 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(qy[ib].qs); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4084,15 +4413,15 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q4_0_r8_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q4_0_r8_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif template -static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); +static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m5 = _mm256_set1_epi8(0x10); #ifndef HAVE_FANCY_SIMD @@ -4139,7 +4468,7 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales)); } for (int k = 0; k < 4; ++k) { @@ -4157,9 +4486,10 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4172,12 +4502,12 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q5_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x); } else { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto m5 = _mm512_set1_epi8(0x10); int nb = n / QK5_0; @@ -4219,7 +4549,7 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16))); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]); @@ -4236,9 +4566,10 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4254,15 +4585,15 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q5_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q5_0_r4_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif template -static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); +static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m6 = _mm256_set1_epi8(0x30); auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f)); @@ -4307,7 +4638,7 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); } for (int k = 0; k < 4; ++k) { @@ -4325,9 +4656,10 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]); } } @@ -4341,12 +4673,12 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q6_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x); } else { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto m6 = _mm512_set1_epi8(0x30); int nb = n / QK6_0; @@ -4386,7 +4718,7 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, scales); } for (int k = 0; k < 4; ++k) { @@ -4404,9 +4736,10 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4422,8 +4755,8 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q6_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q6_0_r4_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif @@ -4463,20 +4796,15 @@ inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) { return sumi; } inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) { - qx[0] = _mm256_loadu_si256((const __m256i *)x+0); - qx[1] = _mm256_loadu_si256((const __m256i *)x+1); - qx[2] = _mm256_loadu_si256((const __m256i *)x+2); - qx[3] = _mm256_loadu_si256((const __m256i *)x+3); - qx[4] = _mm256_loadu_si256((const __m256i *)x+4); - qx[5] = _mm256_loadu_si256((const __m256i *)x+5); - qx[6] = _mm256_loadu_si256((const __m256i *)x+6); - qx[7] = _mm256_loadu_si256((const __m256i *)x+7); + for (int i = 0; i < 8; ++i) { + qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127)); + } return qx_r8_q8_dot_product(qx, y); } template -static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%16 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / QK8_0; if constexpr (nrc_y == 1) { __m256 acc[2] = {}; @@ -4485,7 +4813,8 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int ix = 0; ix < nrc_x; ix += 8) { const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16); + _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux)); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx); @@ -4499,9 +4828,10 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int ib = 4*(nb/4); ib < nb; ++ib) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); - acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]); } } info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); @@ -4516,7 +4846,8 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); @@ -4525,6 +4856,7 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int j = 0; j < 8; ++j) { qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); } for (int iy = 0; iy < nrc_y; ++iy) { auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k); @@ -4541,13 +4873,15 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int j = 0; j < 8; ++j) { qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)), _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); } for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4560,9 +4894,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m1 = _mm256_set1_epi16(1); int nb = n / QK8_0; __m256 acc[nrc_y] = {}; @@ -4585,7 +4919,7 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm_storeu_ps(d8 + 4*iy, scales); } for (int k = 0; k < 4; ++k) { @@ -4617,9 +4951,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn sx[j] = _mm256_sign_epi8(qx[j], qx[j]); } for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; + auto qy = (const block_q8_2 *)q8.y[iy]; auto sumi = dot(qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d}))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } for (int j = 0; j < 4; ++j) { @@ -4627,9 +4961,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn sx[j] = _mm256_sign_epi8(qx[j], qx[j]); } for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; + auto qy = (const block_q8_2 *)q8.y[iy]; auto sumi = dot(qy[ib].qs+16); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d}))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } @@ -6353,7 +6687,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn // The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) template static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); @@ -6376,6 +6710,11 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn auto s1 = _mm256_sign_epi8(qx[1], qx[1]); auto s2 = _mm256_sign_epi8(qx[2], qx[2]); auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#else + qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127)); #endif for (int iy = 0; iy < nrc_y; ++iy) { auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib); @@ -6415,6 +6754,346 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } +// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) +template +static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + GGML_ASSERT(nrc_x%8 == 0); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nb = n / 16; + __m256i acc[nrc_y] = {}; + __m256i qx[4]; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + float sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 8) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto dx = _mm256_loadu_ps(dptr); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows + qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0); + qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1); + qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2); + qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3); +#ifndef HAVE_FANCY_SIMD + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#else + qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127)); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy])); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy])); +#endif + info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy]))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + +template +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + if (nrc_y == 1 && nrc_x == 1) { + auto dx = (const float *)vx; + auto dy = (const float *)info.src1_row(0); +#ifdef HAVE_FANCY_SIMD + auto sy = (const int32_t *)(dy + 1); + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm512_setzero_si512(); + for (int i = 0; i < n/64; ++i) { + auto qx = _mm512_loadu_si512((const __m512i *)x + i); + auto qy = _mm512_loadu_si512((const __m512i *)y + i); + isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy); + } + auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1)); + for (int i = 2*(n/64); i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy); + } + info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0])); +#else + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx)); + isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot)); + } + info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum)); +#endif + return; + } + __m256i qx[2]; + __m256i acc[2*nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#else + __m256i sx[2]; + auto m1 = _mm256_set1_epi16(1); +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr+1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); +#else + qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); +#endif + } + } + } + if (int i = 2*(n/64); i < n/32) { +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); + sx[0] = _mm256_sign_epi8(qx[0], qx[0]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); +#else + auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); +#ifdef HAVE_FANCY_SIMD + info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); +#else + info.store(ix, iy, dx[0]*dy[iy]*sumi); +#endif + acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m512i qx[4]; + __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; + float dy[nrc_y]; + int32_t sy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -64*iptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[8]; + float dx[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int kx = 0; kx < 8; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) { + qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), + _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); + } + auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); + qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); + qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + if constexpr (nrc_y <= 4) { + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } else { + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + } + auto scales_x = _mm256_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (nrc_y <= 4) { + auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } else { + acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[iy] = _mm512_setzero_si512(); + } + } + } +} +#endif + +template +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(n%32 == 0); + __m256i qx[4]; +#ifndef HAVE_FANCY_SIMD + __m256i sx[4]; + auto m1 = _mm256_set1_epi16(1); +#endif + __m256i acc[nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i); + auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]); +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); + qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]); + qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]); + qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2)); + auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34)); +#endif + } + } + auto scales_x = _mm_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1)); +#ifdef HAVE_FANCY_SIMD + sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy])); +#endif + auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy])); + info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + #ifdef __AVX512BF16__ template static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -6988,7 +7667,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI template inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); @@ -7004,7 +7683,7 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons sumi[1] = _mm256_add_epi32(p2, p4); #endif } else { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); @@ -7191,7 +7870,7 @@ struct DequantizerIQ1BN { _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), }; const __m256i m3 = _mm256_set1_epi16(3); -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); #endif @@ -7202,7 +7881,7 @@ struct DequantizerIQ1BN { auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3); auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3); auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3); -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ v1 = _mm256_permutex2var_epi8(val1, bmask, val2); v2 = _mm256_permutex2var_epi8(val3, bmask, val4); #else @@ -7221,7 +7900,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const __m256i accd[nrc_y]; __m256i val[4]; -#if !(defined __AVX512VNNI__ && defined __AVX512VL__) +#ifndef HAVE_FANCY_SIMD const auto m1_16 = _mm256_set1_epi16(1); #endif @@ -7243,7 +7922,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3)); #else @@ -7267,7 +7946,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), @@ -7288,7 +7967,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const if (i < nb) { deq.prepare_iq1bn_quants(x + i, val[0], val[1]); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); #else @@ -7340,7 +8019,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const __m256i accd[nrc_y]; __m256i val[4]; -#if !(defined __AVX512VNNI__ && defined __AVX512VL__) +#ifndef HAVE_FANCY_SIMD const auto m1_16 = _mm256_set1_epi16(1); #endif @@ -7352,7 +8031,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const __m256i acc[2] = {}; for (int i = 0; i < nb/2; ++i) { deq.prepare4(i, val); -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)), @@ -7375,7 +8054,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { deq.prepare4(i, val); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3)); @@ -7394,7 +8073,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const if (i < nb) { deq.prepare2(i, val); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); #else @@ -7483,7 +8162,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; struct EvenSignHelper { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ union sbits_t { __m128i vec; __mmask32 mask[4]; @@ -7548,7 +8227,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { } IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values); #else esh.sign_value(signs[0] | (signs[1] << 16), values[0]); @@ -7723,7 +8402,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone)); } inline void sign_values(const __m256i& data, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9)); auto pcnt = _mm_popcnt_epi8(partial_bits); auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7)); @@ -7773,7 +8452,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { constexpr static int minv = 43; SimpleBits bits; -#ifndef HAVE_FANCY_SIMD +#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__) Helper helper; #endif const __m256i idx_mask = _mm256_set1_epi16(511); @@ -7818,7 +8497,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { } IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); #else @@ -7904,6 +8583,22 @@ template struct return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 } } + inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); } +}; + +template struct Sum4q4 { + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0 + auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1 + auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2 + auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3 + auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1 + auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3 + auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 + return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123); + } + inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); } }; struct ScaleHelperQ8_0 { @@ -7956,6 +8651,29 @@ struct ScaleHelperQ_0_1 { const __m128 min = _mm_set1_ps(float(-min_value)); }; +//template +//struct ScaleHelperQ_0_2 { +// ggml_bf16_t scales8[4]; +// template +// inline __m256 prepare4(const Q * y) { +// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; +// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16)); +// return _mm256_set_m128(_mm_mul_ps(s4, min), s4); +// } +// template +// inline __m256 prepare4(__m256 other_scales, const Q * y) { +// return _mm_mul256_ps(other_scales, prepare4(y)); +// } +// template inline std::pair prepare1(const Q * y) const { +// float d = GGML_BF16_TO_FP32(y->d); +// return std::make_pair(d, -d*float(min_value)); +// } +// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { +// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); +// } +// const __m128 min = _mm_set1_ps(float(-min_value)); +//}; + struct ScaleHelperQ8_1 { template inline __m256 prepare4(const Q * y) { @@ -7977,6 +8695,30 @@ struct ScaleHelperQ8_1 { } }; +struct ScaleHelperQ8_2 { + template + inline __m256 prepare4(const Q * y) { + const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y; + auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d)); + return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16)); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_2 * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } +}; + struct ScaleHelperQ_1 { uint32_t scales8[4]; const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); @@ -8011,6 +8753,7 @@ struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } + inline __m256 vresult(__m256 acc, int) const { return acc; } }; template struct MinusType1 { @@ -8030,6 +8773,9 @@ template struct MinusType1 { const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } + inline __m256 vresult(__m256 acc, int iy) const { + return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0)); + } }; template struct AccumT { @@ -8057,7 +8803,7 @@ template struct AccumT { for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare1(other_scales, y[iy] + i); auto d = accm.compute(s12, iy); - const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); } } @@ -8066,6 +8812,36 @@ template struct AccumT { info.store(ix, iy, accm.result(acc[iy], iy)); } } + template + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i < nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + result[iy] = accm.vresult(acc[iy], iy); + } + } }; template @@ -8074,10 +8850,8 @@ using AccumType0 = AccumT; template using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; -using Sum4Type0 = Sum4; -using Sum4Type1 = Sum4; using Sum4TypeQ80 = Sum4; -using Sum4TypeQ81 = Sum4; +using Sum4TypeQ82 = Sum4; template void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { @@ -8091,6 +8865,19 @@ void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& in } } +template +void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + GGML_ASSERT(nrc_x%2 == 0); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ix += 2) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + template void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8107,6 +8894,63 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info } } +inline __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); + //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} + +template +void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_0 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + + template void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8123,6 +8967,68 @@ void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info } } +template +void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + +template +void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_2 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { @@ -8160,7 +9066,11 @@ struct Q4_0_1_Dequantizer { struct IQ4_NL_Dequantizer { Dequantizer4bit b4; +#ifdef HAVE_FANCY_SIMD const __m256i values = load_iq4nl_values_256(); +#else + const __m256i values = load_iq4k_values_256(); +#endif inline __m256i dequant(const block_iq4_nl * x) const { return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); } @@ -8251,73 +9161,6 @@ struct Q_Unpacker { } }; -struct Q8_0_x4_Unpacker_256 { - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} - - const char * cx_0; - const block_q8_0_x4 * x; - size_t bx; - - __m256i qx[4]; - - inline const __m256i* quants() const { return qx; } - - inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } - - inline auto set_block_4(int i) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); - for (int j = 0; j < 4; ++j) { - qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); - } - return scales; - } - inline auto set_block(int i) { - auto q8 = (const block_q8_0 *)(x + i); - qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); - return GGML_FP16_TO_FP32(q8->d); - } -}; - -#ifdef HAVE_FANCY_SIMD -struct Q8_0_x4_Unpacker_512 { - using Sum4T = Sum4TypeQ81; - inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} - - const char * cx_0; - const block_q8_0_x4 * x; - size_t bx; - const __m128 min = _mm_set1_ps(-128.f); - - __m256i qx[4]; - - inline const __m256i* quants() const { return qx; } - - inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } - - inline auto set_block_4(int i) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); - for (int j = 0; j < 4; ++j) { - qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); - qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(-128)); - } - return _mm256_set_m128(_mm_mul_ps(scales, min), scales); - } - inline auto set_block(int i) { - auto q8 = (const block_q8_0 *)(x + i); - qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); - qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(-128)); - float d = GGML_FP16_TO_FP32(q8->d); - return std::make_pair(d, -128.f*d); - } -}; -using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512; -#else -using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256; -#endif - struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -8325,7 +9168,7 @@ struct Q8_0_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK8_0; } }; struct Q4_0_Unpacker final : public Q_Unpacker { @@ -8335,14 +9178,23 @@ struct Q4_0_Unpacker final : public Q_Unpacker, Q4_0_1_Dequantizer> { Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + //using Sum4T = Sum4TypeQ82; + using Sum4T = Sum4q4; inline static int block_size() { return QK4_0; } }; +#ifdef HAVE_FANCY_SIMD struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_NL; } }; +#else +struct IQ4_NL_Unpacker final : public Q_Unpacker { + IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_NL; } +}; +#endif struct Q5_0_Unpacker final : public Q_Unpacker { Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -8350,22 +9202,22 @@ struct Q5_0_Unpacker final : public Q_Unpacker, Q5_1_Dequantizer> { Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK5_0; } }; struct Q4_1_Unpacker final : public Q_Unpacker { Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4Type1; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_1; } }; struct Q5_1_Unpacker final : public Q_Unpacker> { Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4Type1; - inline static int block_size() { return QK4_1; } + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK5_1; } }; struct Q6_0_1_Unpacker final : public Q_Unpacker, Q6_0_1_Dequantizer> { Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK6_0; } }; @@ -8378,6 +9230,9 @@ struct QFBase { using Acc = __m512; static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); } static inline Data load(const float * x) { return _mm512_loadu_ps(x); } + static inline Data load(const ggml_bf16_t * x) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16)); + } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm512_fmadd_ps(y, x, prev); } @@ -8473,7 +9328,7 @@ template struct QFT final : public QFBase { xv[1] = load1(ix+1, i); xv[2] = load1(ix+2, i); xv[3] = load1(ix+3, i); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]); @@ -8850,18 +9705,47 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_0_q8_0_T; m.funcs[7] = mul_mat_qX_0_q8_0_T; } - else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_1_T; - m.funcs[1] = mul_mat_qX_1_q8_1_T; - m.funcs[2] = mul_mat_qX_1_q8_1_T; - m.funcs[3] = mul_mat_qX_1_q8_1_T; - m.funcs[4] = mul_mat_qX_1_q8_1_T; - m.funcs[5] = mul_mat_qX_1_q8_1_T; - m.funcs[6] = mul_mat_qX_1_q8_1_T; - m.funcs[7] = mul_mat_qX_1_q8_1_T; + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; + } + else if constexpr (std::is_same_v) { +#ifdef HAVE_FANCY_SIMD + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; +#else + m.funcs[0] = mul_mat_qX_0_q8_0_T; + m.funcs[1] = mul_mat_qX_0_q8_0_T; + m.funcs[2] = mul_mat_qX_0_q8_0_T; + m.funcs[3] = mul_mat_qX_0_q8_0_T; + m.funcs[4] = mul_mat_qX_0_q8_0_T; + m.funcs[5] = mul_mat_qX_0_q8_0_T; + m.funcs[6] = mul_mat_qX_0_q8_0_T; + m.funcs[7] = mul_mat_qX_0_q8_0_T; +#endif + } + else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -9173,33 +10057,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q6_0: assert (ne00 % QK6_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); #ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; #else MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0_X4; @@ -9208,19 +10092,23 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ4_NL: assert (ne00 % QK4_NL == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; +#ifdef HAVE_FANCY_SIMD + expected_typeB = GGML_TYPE_Q8_2_X4; +#else + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif break; case GGML_TYPE_IQ4_NL_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_iq4_nl_r4_q8_1<1>; - mm.funcs[1] = mul_mat_iq4_nl_r4_q8_1<2>; - mm.funcs[2] = mul_mat_iq4_nl_r4_q8_1<3>; - mm.funcs[3] = mul_mat_iq4_nl_r4_q8_1<4>; - mm.funcs[4] = mul_mat_iq4_nl_r4_q8_1<5>; - mm.funcs[5] = mul_mat_iq4_nl_r4_q8_1<6>; - mm.funcs[6] = mul_mat_iq4_nl_r4_q8_1<7>; - mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_iq4_nl_r4_q8_2<1>; + mm.funcs[1] = mul_mat_iq4_nl_r4_q8_2<2>; + mm.funcs[2] = mul_mat_iq4_nl_r4_q8_2<3>; + mm.funcs[3] = mul_mat_iq4_nl_r4_q8_2<4>; + mm.funcs[4] = mul_mat_iq4_nl_r4_q8_2<5>; + mm.funcs[5] = mul_mat_iq4_nl_r4_q8_2<6>; + mm.funcs[6] = mul_mat_iq4_nl_r4_q8_2<7>; + mm.funcs[7] = mul_mat_iq4_nl_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_IQ4_XS_R8: assert (ne00 % QK_K == 0); @@ -9393,6 +10281,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>; + mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_q8_KV_q8_KV<16>; +#endif + expected_typeB = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>; + mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>; + expected_typeB = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ4_K_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>; @@ -9448,54 +10363,68 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q4_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q4_0_r8_q8_1<1>; - mm.funcs[1] = mul_mat_q4_0_r8_q8_1<2>; - mm.funcs[2] = mul_mat_q4_0_r8_q8_1<3>; - mm.funcs[3] = mul_mat_q4_0_r8_q8_1<4>; - mm.funcs[4] = mul_mat_q4_0_r8_q8_1<5>; - mm.funcs[5] = mul_mat_q4_0_r8_q8_1<6>; - mm.funcs[6] = mul_mat_q4_0_r8_q8_1<7>; - mm.funcs[7] = mul_mat_q4_0_r8_q8_1<8>; + mm.funcs[0] = mul_mat_q4_0_r8_q8_2<1>; + mm.funcs[1] = mul_mat_q4_0_r8_q8_2<2>; + mm.funcs[2] = mul_mat_q4_0_r8_q8_2<3>; + mm.funcs[3] = mul_mat_q4_0_r8_q8_2<4>; + mm.funcs[4] = mul_mat_q4_0_r8_q8_2<5>; + mm.funcs[5] = mul_mat_q4_0_r8_q8_2<6>; + mm.funcs[6] = mul_mat_q4_0_r8_q8_2<7>; + mm.funcs[7] = mul_mat_q4_0_r8_q8_2<8>; #ifdef HAVE_FANCY_SIMD - mm.func16 = mul_mat_q4_0_r8_q8_1<16>; + mm.func16 = mul_mat_q4_0_r8_q8_2<16>; #endif - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_0_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q5_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q5_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q5_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q5_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q5_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q5_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q5_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q5_0_r4_q8_2<1>; + mm.funcs[1] = mul_mat_q5_0_r4_q8_2<2>; + mm.funcs[2] = mul_mat_q5_0_r4_q8_2<3>; + mm.funcs[3] = mul_mat_q5_0_r4_q8_2<4>; + mm.funcs[4] = mul_mat_q5_0_r4_q8_2<5>; + mm.funcs[5] = mul_mat_q5_0_r4_q8_2<6>; + mm.funcs[6] = mul_mat_q5_0_r4_q8_2<7>; + mm.funcs[7] = mul_mat_q5_0_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q6_0_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q6_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q6_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q6_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q6_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q6_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q6_0_r4_q8_2<1>; + mm.funcs[1] = mul_mat_q6_0_r4_q8_2<2>; + mm.funcs[2] = mul_mat_q6_0_r4_q8_2<3>; + mm.funcs[3] = mul_mat_q6_0_r4_q8_2<4>; + mm.funcs[4] = mul_mat_q6_0_r4_q8_2<5>; + mm.funcs[5] = mul_mat_q6_0_r4_q8_2<6>; + mm.funcs[6] = mul_mat_q6_0_r4_q8_2<7>; + mm.funcs[7] = mul_mat_q6_0_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q8_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q8_0_r8_q8_1<1>; - mm.funcs[1] = mul_mat_q8_0_r8_q8_1<2>; - mm.funcs[2] = mul_mat_q8_0_r8_q8_1<3>; - mm.funcs[3] = mul_mat_q8_0_r8_q8_1<4>; - mm.funcs[4] = mul_mat_q8_0_r8_q8_1<5>; - mm.funcs[5] = mul_mat_q8_0_r8_q8_1<6>; - mm.funcs[6] = mul_mat_q8_0_r8_q8_1<7>; - mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q8_0_r8_q8_2<1>; + mm.funcs[1] = mul_mat_q8_0_r8_q8_2<2>; + mm.funcs[2] = mul_mat_q8_0_r8_q8_2<3>; + mm.funcs[3] = mul_mat_q8_0_r8_q8_2<4>; + mm.funcs[4] = mul_mat_q8_0_r8_q8_2<5>; + mm.funcs[5] = mul_mat_q8_0_r8_q8_2<6>; + mm.funcs[6] = mul_mat_q8_0_r8_q8_2<7>; + mm.funcs[7] = mul_mat_q8_0_r8_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; + break; + case GGML_TYPE_IQ1_S: + mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; + mm.funcs[1] = mul_mat_iq1_s_q8_K<2>; + mm.funcs[2] = mul_mat_iq1_s_q8_K<3>; + mm.funcs[3] = mul_mat_iq1_s_q8_K<4>; + mm.funcs[4] = mul_mat_iq1_s_q8_K<5>; + mm.funcs[5] = mul_mat_iq1_s_q8_K<6>; + mm.funcs[6] = mul_mat_iq1_s_q8_K<7>; + mm.funcs[7] = mul_mat_iq1_s_q8_K<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_s_q8_K<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K; break; case GGML_TYPE_IQ1_S_R4: assert (ne00 % QK4_NL == 0); @@ -11241,9 +12170,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i q8.process_scales(i, deq, sc16, acc); sum_4(i, deq, q8, sc16, acc); } - //for (int i = 4*(nb/4); i < nb; ++i) { - // q8.process_1_block(i, deq, acc); - //} + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq, acc); + } for (int iy = 0; iy < Q8::nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); @@ -11318,15 +12247,13 @@ static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { - if (nrc_x%2 == 0) { + if (nrc_x%2 == 0 && n%128 == 0) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } - //Dequantizer deq(vx, bx); - //mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } @@ -11337,7 +12264,7 @@ static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { - if (nrc_x%2 == 0) { + if (nrc_x%2 == 0 && n%128 == 0) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x); } else { @@ -12472,7 +13399,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; - int32x4_t acc[nrc_y] = {}; + float32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); float d8[4*nrc_y]; @@ -12537,6 +13464,68 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +template +static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8 q8(info); + int8x16_t qx[16]; + int32x4_t scales[2]; + int16x4_t deltas[2]; + float32x4_t acc[nrc_y] = {}; + auto delta_mask = vdupq_n_u16(0x8000); + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + float d = GGML_FP16_TO_FP32(iq1s[ibl].d); + auto qhb = vld1q_u16(iq1s[ibl].qh); + auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7)); + scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1)); + auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask); + // Note: we explicitely assume IQ1S_DELTA = 0.125 + auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask)); + //auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask)); + //deltas128 = vmulq_s16(scales128, deltas128); + scales128 = vshlq_n_u16(scales128, 3); + auto qs = iq1s[ibl].qs; + auto qh = iq1s[ibl].qh; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]}); + qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]}); + qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]}); + qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]}); + qs += 8; + } + scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128))); + scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128))); + deltas[0] = vget_low_s16 (deltas128); + deltas[1] = vget_high_s16(deltas128); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums8(iy, ibl); + auto sumi = vdupq_n_s32(0); + sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums)); + sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums)); + for (int k = 0; k < QK_K/128; ++k) { + auto qy = q8.load_quants_64(iy, ibl, 2*k+0); + auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]); + auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]); + auto dot12 = vpaddq_s32(dot1, dot2); + qy = q8.load_quants_64(iy, ibl, 2*k+1); + auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]); + auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]); + auto dot34 = vpaddq_s32(dot3, dot4); + auto dot = vpaddq_s32(dot12, dot34); + sumi = vmlaq_s32(sumi, dot, scales[k]); + } + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi)); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy])); + acc[iy] = vdupq_n_f32(0); + } + } +} + template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -13703,6 +14692,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf } } +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + int32x4_t acc[4] = {}; + auto dptr = (const float *)info.src1_row(0); + const float dy = dptr[0]; + auto q8y = (const int8_t *)(dptr + 2); + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + auto qx = vld1q_s8_x4(q8x + 64*i); + for (int j = 0; j < 4; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j)); + } + } + if (int i = 2*(n/64); i < n/32) { + auto qx = vld1q_s8_x2(q8x + 32*i); + for (int j = 0; j < 2; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j)); + } + } + acc[0] = vaddq_s32(acc[0], acc[1]); + acc[2] = vaddq_s32(acc[2], acc[3]); + acc[0] = vaddq_s32(acc[0], acc[2]); + info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0])); + acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0); + } +} + +template +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(n%16 == 0); + int8x16_t qx[4]; + int32x4_t acc[nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/16; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i); + auto row01 = vtrnq_s32(qx[0], qx[1]); + auto row23 = vtrnq_s32(qx[2], qx[3]); + qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]); + qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]); + qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]); + qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy] + 16*i); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3); + } + } + auto scales_x = vld1q_f32(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy])); + info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy]))); + acc[iy] = vdupq_n_s32(0); + } + } +} + +template +void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + int32x4_t acc[2*nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 8) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n/16; ++ib) { + auto q1 = vld1q_s8_x4(q8x + 128*ib + 0); + auto q2 = vld1q_s8_x4(q8x + 128*ib + 64); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy]+16*ib); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3); + } + } + auto scale1_x = vld1q_f32(dptr+0); + auto scale2_x = vld1q_f32(dptr+4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale_y = vdupq_n_f32(dy[iy]); + auto scale1 = vmulq_f32(scale1_x, scale_y); + auto scale2 = vmulq_f32(scale2_x, scale_y); + info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); + info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0.f); + } + } +} + void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<1, block_q8_0_x4> q8(info); @@ -13892,8 +14998,8 @@ struct Q4_0_R8_Dequantizer { float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) }; for (int j = 0; j < 4; ++j) { auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j); - //bits.val[0] = veorq_u8(m88, bits.val[0]); - //bits.val[1] = veorq_u8(m88, bits.val[1]); + bits.val[0] = veorq_u8(m88, bits.val[0]); + bits.val[1] = veorq_u8(m88, bits.val[1]); qx[2*j+0] = vshlq_n_u8(bits.val[0], 4); qx[2*j+1] = vandq_u8(bits.val[0], m4); qx[2*j+8] = vshlq_n_u8(bits.val[1], 4); @@ -14234,6 +15340,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq2_s_r4_q8_k<16>; expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ1_S: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K); + m.func16 = mul_mat_iq1_s_q8_K<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ1_S_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; @@ -14279,6 +15390,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k); expected_Btype = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV); + m.funcs[0] = mul_mat_q8_KV_q8_KV_1; + m.func16 = mul_mat_q8_KV_q8_KV<16>; + expected_Btype = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV); + expected_Btype = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ2_K_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; @@ -14366,10 +15487,49 @@ inline float32x4_t v_tanh(float32x4_t x) { return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); } -inline float32x4_t v_tanh(float16x8_t x) { - auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); - auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); - return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); +//inline float32x4_t v_tanh(float16x8_t x) { +// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); +// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); +// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); +//} +inline float32x4_t v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} +inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { + const float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); + arg = vmulq_f32(arg, vmulq_f32(x, c2)); + float32x4_t exp_arg = v_expf(arg); + float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one))); + uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f)); + return vbslq_f32(mask, x, gelu); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + int i = 0; + auto c1 = vdupq_n_f32(GELU_COEF_A); + auto c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI); + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, v_gelu(vld1q_f32(x + i), c1, c2)); + } + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; + for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i))); + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; } #endif @@ -14413,6 +15573,24 @@ inline __m512 v_tanh(__m512 x) { const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); return _mm512_mask_blend_ps(mask, res, one); } +inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) { + const __m512 one = _mm512_set1_ps(1.0f); + __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); + //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1)); + arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x)); + const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ); + const __m512 exp_arg = v_expf(arg); + const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)); + return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one)); +} +inline static __m512 v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} #endif #if defined(__AVX2__) && defined(__FMA__) @@ -14466,6 +15644,61 @@ inline __m256 v_tanh(__m256 x) { const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res)); } +inline static __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); + __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1)); + arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2)); + __m256 exp_arg = v_expf(arg); + __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one))); + return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu)); +} +inline static __m256 v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + //GGML_ASSERT(n%8 == 0); + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + { + __m512 c1 = _mm512_set1_ps(GELU_COEF_A); + __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2)); + } +#endif +#if defined __AVX2__ && defined __FMA__ + if (i + 7 < n) { + __m256 c1 = _mm256_set1_ps(GELU_COEF_A); + __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); + + } +#endif + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_silu(_mm512_loadu_ps(x + i))); +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_silu(_mm256_loadu_ps(x + i))); +#endif + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; +} #endif } // namespace @@ -14476,7 +15709,7 @@ template struct BaseHelper { BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {} - inline void set_block(int k1) { block = data + k1*k_step*stride; } + //inline void set_block(int k1) { block = data + k1*k_step*stride; } inline void reset_block() { block = data; } inline void next_block() { block += k_step*stride; } inline const char * lblock(int l1) const { return block + l1*stride; } @@ -14488,7 +15721,7 @@ struct BaseHelper { }; struct F16 { -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ using Data = __m512; constexpr static int block_size = 16; constexpr static int num_registers = 32; @@ -14510,6 +15743,13 @@ struct F16 { auto v256 = _mm256_set_m128(v128, v128); return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm512_shuffle_ps(v, v, 0x00); + vs[1] = _mm512_shuffle_ps(v, v, 0x55); + vs[2] = _mm512_shuffle_ps(v, v, 0xaa); + vs[3] = _mm512_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); } @@ -14535,6 +15775,13 @@ struct F16 { auto v128 = _mm_loadu_ps(ptr); return _mm256_set_m128(v128, v128); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm256_shuffle_ps(v, v, 0x00); + vs[1] = _mm256_shuffle_ps(v, v, 0x55); + vs[2] = _mm256_shuffle_ps(v, v, 0xaa); + vs[3] = _mm256_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); } @@ -14626,13 +15873,49 @@ struct HelperF16 final : public BaseHelper { } }; +template struct block_q8_KV { + float d; + int s; + int8_t qs[D]; +}; + +template +struct HelperQ8KV final : public BaseHelper { + using Base = BaseHelper; + using block_q8 = block_q8_KV; + constexpr static int block_size_q = D; + HelperQ8KV(const char * data, int stride) : Base(data, stride) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + auto q8 = (const block_q8_KV *)Base::lblock(l1); +#ifdef __aarch64__ + auto vd = F16::set1(q8->d); + auto qs = vld1_s8_x2(q8->qs + 8*i); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); +#else + auto vd = F16::set1(q8->d); +#ifdef __AVX512F__ + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0)))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1)))); +#else + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0))))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8))))); +#endif +#endif + } +}; + template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; #ifdef HAVE_FANCY_SIMD - using block_q8 = block_q8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; #else using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #endif HelperQ80(const char * data, int stride) : Base(data, stride) {} @@ -14648,7 +15931,7 @@ struct HelperQ80 final : public BaseHelper { v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0)))); v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1)))); #else @@ -14676,34 +15959,162 @@ struct HelperQ80 final : public BaseHelper { y += D/QK8_1; } } -}; + static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) { + //GGML_ASSERT(nq <= step); Why did I have this assert? + for (int i = 0; i < nq; ++i) { + quantize_row_q8_2_x4(q, y, D); + q += stride_q; + y += D/QK8_2; + } + } + + static inline void convert(int nq, int stride_q, const float * q, block_q8_KV * y) { + for (int i = 0; i < nq; ++i) { + quantize_row_q8_KV(q, y, D); + q += stride_q; + ++y; + } + } +}; +} + +void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) { + repacked_type = int_type_k; + auto type_k = ggml_type(int_type_k); + if (type_k != GGML_TYPE_Q8_0 || nek0%QK8_0 != 0) return work; + int nrows = nek1*nek2*nek3; + if (nrows%8 != 0) return work; + repacked_type = int(GGML_TYPE_Q8_0_R8); + row_size = ggml_row_size(GGML_TYPE_Q8_0, nek0); + void * result = (char *)work + nrows*row_size; + int npt = 8*((nrows/8 + nth - 1)/nth); + int first = npt*ith; + if (first >= nrows) return result; + int last = std::min(first + npt, nrows); + const block_q8_0 * x8[8]; + auto y = (block_q8_0_r8 *)((char *)work + first*row_size); + int nblock = nek0/QK8_0; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = first; row < last; row += 8) { + int ik3 = row/(nek1*nek2); + int ik2 = (row - ik3*nek1*nek2)/nek1; + int ik1 = row - ik3*nek1*nek2 - ik2*nek1; + auto this_data = (const char *)data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3; + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(this_data + k*nbk1); + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); + m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1)); + m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1)); + m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1)); + m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1)); + t0 = _mm256_unpacklo_epi32(m0, m1); + t1 = _mm256_unpacklo_epi32(m2, m3); + t2 = _mm256_unpackhi_epi32(m0, m1); + t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); +#elif defined __ARM_NEON + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } +#else + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + y += nblock; + } + return result; +} + +namespace { template -struct HelperQ80R4 : public BaseHelper { +struct HelperQ80R8 : public BaseHelper { using Base = BaseHelper; #ifdef __AVX2__ - using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_2; + using block_q8 = block_q8_2; #else + constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; #endif - HelperQ80R4(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { + HelperQ80R8(const char * data, int stride) : Base(data, stride) {} + HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector repack(int nk, const HelperQ80 q8) { - static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%8 == 0); + static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) { constexpr int nblock = D/QK8_0; - std::vector result(nblock * nk/8); - auto y = result.data(); const block_q8_0 * x8[8]; #ifdef __ARM_NEON int8x16x2_t m0, m1, m2, m3; #endif for (int row = 0; row < nk; row += 8) { - for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride); for (int ib = 0; ib < nblock; ++ib) { for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ @@ -14719,12 +16130,12 @@ struct HelperQ80R4 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); -#ifdef HAVE_FANCY_SIMD - m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); - m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); - m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); - m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); -#endif +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); @@ -14741,12 +16152,12 @@ struct HelperQ80R4 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); -#ifdef HAVE_FANCY_SIMD - m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); - m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); - m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); - m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); -#endif +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); @@ -14785,16 +16196,127 @@ struct HelperQ80R4 : public BaseHelper { } y += nblock; } + } + + static std::vector repack(int nk, const HelperQ80& q8) { + static_assert(D%QK8_0 == 0); + GGML_ASSERT(nk%8 == 0); + constexpr int nblock = D/QK8_0; + std::vector result(nblock * nk/8); + auto y = result.data(); + repack(nk, q8.data, q8.stride, y); return result; } std::vector r4; }; +// TODO: unite this with the above +template +struct HelperQ8KVR8 : public BaseHelper { + using Base = BaseHelper; + constexpr static int block_size_q = D; + using block_q8 = block_q8_KV; + + struct block_q8_KV_r8 { + float d[8]; + int8_t qs[8*D]; + }; + + HelperQ8KVR8(int nk, const HelperQ8KV& q8) : Base(q8.data, q8.stride) { + r4 = repack(nk, q8); + Base::data = (const char *)r4.data(); + Base::stride = sizeof(block_q8_KV_r8)/8; + } + + static std::vector repack(int nk, const HelperQ8KV& q8) { + static_assert(D%32 == 0); + GGML_ASSERT(nk%8 == 0); + std::vector result(nk/8); + auto y = result.data(); +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + const int8_t * x8[8]; + for (int ix = 0; ix < nk/8; ++ix) { + for (int k = 0; k < 8; ++k) { + auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride); + y[ix].d[k] = dptr[0]; + x8[k] = (const int8_t *)(dptr + 2); + } + for (int ib = 0; ib < D/16; ++ib) { +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3); +#elif defined __ARM_NEON + // TODO + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0); + vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1); + vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2); + vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + } + return result; + } + + std::vector r4; +}; + template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; +#if defined __AVX2__ + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; +#else using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; +#endif HelperQ40(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14811,7 +16333,7 @@ struct HelperQ40 final : public BaseHelper { #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -14837,7 +16359,8 @@ struct HelperQ40 final : public BaseHelper { template struct HelperQ41 final : public BaseHelper { using Base = BaseHelper; - using block_q8 = block_q8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; HelperQ41(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14855,7 +16378,7 @@ struct HelperQ41 final : public BaseHelper { auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_and_si128(q, mask); auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask); v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm); @@ -14882,9 +16405,16 @@ struct HelperIQ4nl final : public BaseHelper { #ifdef __aarch64__ using block_q8 = block_q8_0; HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {} + constexpr static int block_size_q = QK8_0; #else HelperIQ4nl(const char * data, int stride) : Base(data, stride) {} - using block_q8 = block_q8_1; +#ifdef HAVE_FANCY_SIMD + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; +#else + using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; +#endif #endif // Needed for v * softmax(k * q) @@ -14901,7 +16431,7 @@ struct HelperIQ4nl final : public BaseHelper { #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask)); auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask)); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -14928,8 +16458,10 @@ template struct HelperQ60 final : public BaseHelper { #ifdef __aarch64__ using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #else - using block_q8 = block_q8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; #endif using Base = BaseHelper; HelperQ60(const char * data, int stride) : Base(data, stride) {} @@ -14940,7 +16472,9 @@ struct HelperQ60 final : public BaseHelper { auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0; #ifdef __aarch64__ // TODO - auto vd = F16::set1(*(const float16_t *)&dl->d); + const float16_t * d16 = (const float16_t *)&dl->d; + auto vd = F16::set1(d16[0]); + //auto vd = F16::set1(*(const float16_t *)&dl->d); auto qh8 = vld1_u8(dl->qh); auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); auto qs = vld1q_u8(dl->qs); @@ -14954,7 +16488,7 @@ struct HelperQ60 final : public BaseHelper { auto bl = _mm_loadu_si128((const __m128i *)dl->qs); uint64_t aux64; std::memcpy(&aux64, dl->qh, 8); auto bh = _mm_set_epi64x(aux64, aux64 << 4); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32); auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -15088,7 +16622,7 @@ struct FlashMS { return vmaxvq_f32(vmax); } inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { - auto vzero = vdupq_n_f32(0); + auto vzero = vdupq_n_f16(0); auto vinf = vdupq_n_f32(-INFINITY); for (int l = 0; l < k_step/8; ++l) { auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l)); @@ -15096,9 +16630,9 @@ struct FlashMS { auto vm2 = vzip2q_u16(vm, vm); auto kq = vld1q_f32_x2(cache + k_step*j + 8*l); vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1), - vbicq_u32(vinf, vm1))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm1))); vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), - vbicq_u32(vinf, vm2))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm2))); } float32x4_t vmax = vdupq_n_f32(-INFINITY); auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); @@ -15130,6 +16664,23 @@ struct FlashMS { } return F16::reduce_max(vk); } + static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) { + return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l))); + //auto m128 = _mm_loadu_si128((const __m128i *)mask+l); + //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); + //auto m256 = _mm256_cvtepi16_epi32(m128); + //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + } +#ifdef __AVX512F__ + static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) { + auto m256 = _mm256_loadu_si256((const __m256i *)mask+l); + m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256()); + auto m512 = _mm512_cvtepi16_epi32(m256); + auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16))); + return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf)); + } +#endif inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) { #ifdef HAVE_FANCY_SIMD auto vzero = _mm256_set1_epi16(0); @@ -15147,15 +16698,9 @@ struct FlashMS { } } #else - auto vzero = _mm_set1_epi16(0); auto vinf = F16::set1(-INFINITY); for (int l = 0; l < k_step/F16::block_size; ++l) { - auto m128 = _mm_loadu_si128((const __m128i *)mask + l); - m128 = _mm_cmpeq_epi16(m128, vzero); - auto m256 = _mm256_cvtepi16_epi32(m128); - auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); - auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l); - vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf); } if (softcap <= 0) { for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); @@ -15172,23 +16717,23 @@ struct FlashMS { inline void update_M_S(int j, float32x4_t * vk) { float smax = load_and_scale(j, vk); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } inline void update_M_S(int j, float32x4_t * vk, const char * mask) { float smax = load_apply_mask_and_scale(j, vk, mask); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } #else inline void update_M_S(int j, F16::Data * vk) { float smax = load_and_scale(j, vk); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } inline void update_M_S(int j, F16::Data * vk, const char * mask) { float smax = load_apply_mask_and_scale(j, vk, mask); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } #endif @@ -15211,11 +16756,50 @@ struct FlashQKV { using qkv_cache_t = float; #endif + template + inline void accumulate_qkv_1(const VHelper& vh, const FlashMS& fms) { + F16::Data vq[D/F16::block_size]; + if (fms.need_scaling[0] == 2) { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero(); + } else { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i); + if (fms.need_scaling[0] == 1) { + auto vms = F16::set1(fms.vms[0]); + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]); + } + } + F16::Data v0, v1; + for (int l = 0; l < k_step; l += 4) { + auto vs0 = F16::set1(fms.cache[l + 0]); + auto vs1 = F16::set1(fms.cache[l + 1]); + auto vs2 = F16::set1(fms.cache[l + 2]); + auto vs3 = F16::set1(fms.cache[l + 3]); + for (int i = 0; i < D/F16::block_size; i += 2) { + vh.load(l+0, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs0); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs0); + vh.load(l+1, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs1); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs1); + vh.load(l+2, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs2); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs2); + vh.load(l+3, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs3); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs3); + } + } + for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]); + } + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 // Hence, for now, we will not handle head sizes of 80 and 112 template inline void accumulate_qkv(const VHelper& vh, const FlashMS& fms) { - F16::Data v[8]; + if constexpr (q_step == 1) { + accumulate_qkv_1(vh, fms); + return; + } for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; if (fms.need_scaling[j] == 2) { @@ -15228,6 +16812,43 @@ struct FlashQKV { } } } +#ifdef __AVX512F__ + if constexpr ((D/F16::block_size)%4 == 0) { + F16::Data v[16]; + F16::Data vs[4]; + for (int i = 0; i < D/F16::block_size; i += 4) { + for (int l = 0; l < k_step; l += 4) { + for (int k = 0; k < 4; ++k) { + vh.load(l+k, i+0, v[4*k+0], v[4*k+1]); + vh.load(l+k, i+2, v[4*k+2], v[4*k+3]); + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto s3 = F16::load(R + F16::block_size*(i+2)); + auto s4 = F16::load(R + F16::block_size*(i+3)); + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[4*k+0], vs[k]); + s2 = F16::fmadd(s2, v[4*k+1], vs[k]); + s3 = F16::fmadd(s3, v[4*k+2], vs[k]); + s4 = F16::fmadd(s4, v[4*k+3], vs[k]); + } + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); + F16::store(R + F16::block_size*(i+2), s3); + F16::store(R + F16::block_size*(i+3), s4); + } + } + } + return; + } +#endif + F16::Data v[8]; +#ifdef __AVX2__ + F16::Data vs[4]; +#endif for (int i = 0; i < D/F16::block_size; i += 2) { for (int l = 0; l < k_step; l += 4) { vh.load(l+0, i, v[0], v[4]); @@ -15238,6 +16859,13 @@ struct FlashQKV { auto R = qkv_cache + D*j; auto s1 = F16::load(R + F16::block_size*(i+0)); auto s2 = F16::load(R + F16::block_size*(i+1)); +#ifdef __AVX2__ + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[k+0], vs[k]); + s2 = F16::fmadd(s2, v[k+4], vs[k]); + } +#else auto vs = F16::set4(fms.cache + k_step*j + l); s1 = F16::fmadd_lane0(s1, v[0], vs); s2 = F16::fmadd_lane0(s2, v[4], vs); @@ -15247,6 +16875,7 @@ struct FlashQKV { s2 = F16::fmadd_lane2(s2, v[6], vs); s1 = F16::fmadd_lane3(s1, v[3], vs); s2 = F16::fmadd_lane3(s2, v[7], vs); +#endif F16::store(R + F16::block_size*(i+0), s1); F16::store(R + F16::block_size*(i+1), s2); } @@ -15256,6 +16885,10 @@ struct FlashQKV { template inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS& fms) { + if (nq1 == 1) { + accumulate_qkv_1(vh, fms); + return; + } F16::Data v[8]; for (int j = 0; j < nq1; ++j) { auto R = qkv_cache + D*j; @@ -15295,7 +16928,7 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + inline void normalize_and_store_1row(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { GGML_ASSERT(fms.S[j] > 0); auto norm = F16::set1(1/fms.S[j]); //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); @@ -15305,21 +16938,55 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int nq1, int stride_qkv, float * qkv) const { - auto R = qkv_cache; - for (int j = 0; j < nq1; ++j) { - normalize_and_store(fms, j, R, qkv); - qkv += stride_qkv; - R += D; + inline void normalize_and_store(const FlashMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + if (M && S) { + std::memcpy(M, fms.M, nq1*sizeof(float)); + std::memcpy(S, fms.S, nq1*sizeof(float)); + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else + std::memcpy(qkv, R, D*sizeof(float)); +#endif + qkv += stride_qkv; + R += D; + } + } else { + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { + normalize_and_store_1row(fms, j, R, qkv); + qkv += stride_qkv; + R += D; + } } } - inline void normalize_and_store(const FlashMS& fms, int stride_qkv, float * qkv) const { - auto R = qkv_cache; - for (int j = 0; j < q_step; ++j) { - normalize_and_store(fms, j, R, qkv); - qkv += stride_qkv; - R += D; + inline void normalize_and_store(const FlashMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + if (M && S) { + std::memcpy(M, fms.M, q_step*sizeof(float)); + std::memcpy(S, fms.S, q_step*sizeof(float)); + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else + std::memcpy(qkv, R, D*sizeof(float)); +#endif + qkv += stride_qkv; + R += D; + } + } else { + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { + normalize_and_store_1row(fms, j, R, qkv); + qkv += stride_qkv; + R += D; + } } } @@ -15329,12 +16996,13 @@ struct FlashQKV { // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization. - qkv_cache_t qkv_cache[D*q_step] = {}; + qkv_cache_t qkv_cache[D*q_step]; // = {}; + //qkv_cache_t * qkv_cache; }; template struct FlashQKfp32 { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(D%F16::block_size == 0 && D <= 576); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15352,23 +17020,18 @@ struct FlashQKfp32 { #endif constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + static_assert(krem == 0); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); - } info.cur_y += nrc_q; } if constexpr (qrem > 0) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); - } } F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { @@ -15421,7 +17084,7 @@ struct FlashQKfp32 { constexpr int nrc_k = 8; #endif static_assert(k_step%nrc_k == 0); - int qrem = q_step - nrc_q*(q_step/nrc_q); + int qrem = nq - nrc_q*(nq/nrc_q); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { @@ -15471,7 +17134,7 @@ struct FlashQKfp32 { } } F16::Data vk[k_step/F16::block_size]; - for (int j = 0; j < q_step; ++j) { + for (int j = 0; j < nq; ++j) { fms.update_M_S(j, vk, mask + stride_m*j); } } @@ -15576,46 +17239,83 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 4); MAKE_FUNCS(mul_mat_qX_0_q8_0_T>) { + else if constexpr (std::is_same_v>) { +#ifdef __aarch64__ + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#else + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); +#ifdef HAVE_FANCY_SIMD + if constexpr (D%32 == 0 && k_step%8 == 0) { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); + } else { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + } +#endif + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#endif + } + else if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq); #else - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_2, nq); #endif } + else if constexpr (std::is_same_v>) { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq); + } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); + MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0(q_step); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15653,7 +17353,7 @@ struct FlashQKfp32 { static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { auto [mul_mat, nrc_q] = mul_mat_kernel(nq); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15678,13 +17378,14 @@ struct FlashQKfp32 { } }; -template +template void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS& fms, - FlashQKV& fqkv, - const float * q, const char * mask, float * qkv) { + FlashQKV& fqkv, + const float * q, const char * mask, float * qkv, + float * M, float * S) { #ifdef __aarch64__ - float16_t q_f16[D*q_step]; + float16_t q_f16[Dk*q_step]; #endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { @@ -15697,7 +17398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); #endif @@ -15706,11 +17407,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); q += q_step*stride_q; mask += q_step*stride_m; qkv += q_step*stride_qkv; + if (M && S) { M += q_step; S += q_step; } } int n_left = nq1 - q_step*(nq1/q_step); if (n_left > 0) { @@ -15723,7 +17425,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); #endif @@ -15732,16 +17434,39 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } } -template +template void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS& fms, - FlashQKV& fqkv, - const float * q, const char * mask, float * qkv) { - typename KHelper::block_q8 q8[q_step*(D/QK8_0)]; + FlashQKV& fqkv, + const float * q, const char * mask, float * qkv, + float * M, float * S, char * qptr) { + auto q8 = (typename KHelper::block_q8 *)qptr; + if constexpr (q_step > 1 && std::is_same_v>) { + if (nq1 == q_step) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; + HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); + auto q8r = (typename HelperQ80R8::block_q8 *)qptr; + HelperQ80::convert(q_step, stride_q, q, q8r); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + return; + } + } #if FA_TIMING Perf perf(false); #endif @@ -15752,7 +17477,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(q_step, stride_q, q, q8); + HelperQ80::convert(q_step, stride_q, q, q8); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -15775,22 +17500,23 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, } #if FA_TIMING t1 = Perf::cur_time(); - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); perf.accum_nolock(3, t1); #else - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); #endif q += q_step*stride_q; mask += q_step*stride_m; qkv += q_step*stride_qkv; + if (M && S) { M += q_step; S += q_step; } } int n_left = nq1 - q_step*(nq1/q_step); if (n_left > 0) { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(n_left, stride_q, q, q8); + HelperQ80::convert(n_left, stride_q, q, q8); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms); @@ -15799,13 +17525,19 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } #if FA_TIMING Perf::instance().add(perf); #endif } +char * get_q_storage(size_t size) { + thread_local std::vector q_storage; + if (q_storage.size() < size) q_storage.resize(size); + return q_storage.data(); +} + // Some of the methods in FlashAttn have two identical implementations that only differ by // one version using a loop over the template parameter q_step, while the other using a loop // over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, @@ -15816,9 +17548,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, // rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to // process template parameter of such functions, but this would result in the compiler generating // q_step-1 versions of these functions for us, which I though was too much with q_step = 8. -template +template struct FlashAttn { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(Dk%F16::block_size == 0 && Dk <= 576); + static_assert(Dv%F16::block_size == 0 && Dv <= 512); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15826,36 +17559,66 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float * qkv) { - if constexpr (std::is_same_v> || std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { - compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); - } - else if constexpr (std::is_same_v>) { - if (nq1 >= 8) { + const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { + if constexpr (std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v>) { + constexpr size_t kMaxOnStackSize = 576; + //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2)); + q_size = GGML_PAD(q_size, 64); + if (q_size > kMaxOnStackSize) { + auto qptr = get_q_storage(q_size); + if (nq1 >= 8) { + if constexpr (std::is_same_v>) { #if FA_TIMING - auto t1 = Perf::cur_time(); - HelperQ80R4 khr4(nk1, kh); - Perf::instance().accum(4, t1); + auto t1 = Perf::cur_time(); + HelperQ80R8 khr4(nk1, kh); + Perf::instance().accum(4, t1); #else - HelperQ80R4 khr4(nk1, kh); + HelperQ80R8 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); - } else{ - compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + + } + if constexpr (std::is_same_v>) { +#if FA_TIMING + auto t1 = Perf::cur_time(); + HelperQ8KVR8 khr4(nk1, kh); + Perf::instance().accum(4, t1); +#else + HelperQ8KVR8 khr4(nk1, kh); +#endif + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + } + } + compute_helper_q>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + } - } else { - compute_helper>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + else { + typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; + compute_helper_q>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); + } + } + else { + compute_helper>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } } - FlashMS fms; - FlashQKV fqkv; + FlashMS fms; + FlashQKV fqkv; }; @@ -15895,7 +17658,8 @@ struct HelperBF16 final : public BaseHelper { template struct FlashQKbf16 { - static_assert(D%32 == 0 && D <= 256); + //static_assert(D%32 == 0 && D <= 256); + static_assert(D%32 == 0 && D <= 576); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16088,7 +17852,22 @@ struct FlashQKbf16 { static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, const char * mask, FlashMS& fms) { #endif - { + if constexpr (q_step == 1) { + __m512bh vq[D/32]; + __m512bh vk[D/32]; + __m256 sum[8]; + for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i)); + for (int l = 0; l < k_step; l += 8) { + for (int k = 0; k < 8; ++k) { + kh.load(l+k, vk); + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]); + sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + } + _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum)); + } + } + else { __m512bh qv[D/32]; if constexpr (D <= 128) { __m512bh vkh[D/4]; @@ -16188,9 +17967,12 @@ struct FlashQKbf16 { } }; -template +template struct FlashAttnBF16 { - static_assert(D%32 == 0 && D <= 256); + //static_assert(Dk%32 == 0 && Dk <= 256); + //static_assert(Dv%32 == 0 && Dv <= 256); + static_assert(Dk%32 == 0 && Dk <= 576); + static_assert(Dv%32 == 0 && Dv <= 512); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16198,8 +17980,8 @@ struct FlashAttnBF16 { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float * qkv) { - ggml_bf16_t q_bf16[q_step*D]; + const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { + ggml_bf16_t q_bf16[q_step*Dk]; #if FA_TIMING Perf perf(false); #endif @@ -16210,7 +17992,7 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16::convert(stride_q, q, q_bf16); + FlashQKbf16::convert(stride_q, q, q_bf16); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -16218,13 +18000,13 @@ struct FlashAttnBF16 { for (int k1 = 0; k1 < nk1/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); //perf.accum_nolock(1, t1); t1 = Perf::cur_time(); fqkv.accumulate_qkv(vh, fms); perf.accum_nolock(3, t1); #else - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(vh, fms); #endif kh.next_block(); @@ -16234,7 +18016,7 @@ struct FlashAttnBF16 { #if FA_TIMING t1 = Perf::cur_time(); #endif - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); #if FA_TIMING perf.accum_nolock(4, t1); #endif @@ -16248,161 +18030,203 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16::convert(n_left, stride_q, q, q_bf16); + FlashQKbf16::convert(n_left, stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); + FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } #if FA_TIMING Perf::instance().add(perf); #endif } - FlashMS fms; - FlashQKV fqkv; + FlashMS fms; + FlashQKV fqkv; }; #endif -template +template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv) { + const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { - if (nk1 >= 256) { //4096) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; + if (nk1 >= 512) { + if (nq1 >= 128) { + int n_step = nq1/128; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(128*n_step)) return; + } if (nq1 >= 64) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - return; + int n_step = nq1/64; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(64*n_step)) return; } if (nq1 >= 32) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - return; + int n_step = nq1/32; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(32*n_step)) return; } if (nq1 >= 16) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - return; + int n_step = nq1/16; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(16*n_step)) return; } } if (nq1 >= 8) { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + int n_step = nq1/8; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(8*n_step)) return; } - else { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + else if (nq1 >= 4) { + int n_step = nq1/4; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(4*n_step)) return; } + else if (nq1 >= 2) { + int n_step = nq1/2; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(2*n_step)) return; + } + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } #ifdef __AVX512BF16__ -template +template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv) { - HelperBF16 kh(k, stride_k); - HelperBF16 vh(v, stride_v); + float scale, float softcap, float * qkv, float * M, float * S) { + HelperBF16 kh(k, stride_k); + HelperBF16 vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { - FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } else if (nq1 >= 16) { - FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } } if (nq1 >= 8) { - FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { - FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } } #endif -template +template inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { switch (type_v) { case GGML_TYPE_F16: { - HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperF16 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512BF16__ case GGML_TYPE_BF16: { - HelperBF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperBF16 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #endif case GGML_TYPE_Q8_0: { - HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ80 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { - HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ60 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q4_0: { + HelperQ40 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS - case GGML_TYPE_Q4_0: { - HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); - } break; case GGML_TYPE_Q4_1: { - HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ41 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperIQ4nl vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #endif default: break; } } -template +template inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { switch (type_k) { case GGML_TYPE_F16: { - HelperF16 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperF16 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_0: { - HelperQ80 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ80 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q8_0_R8: { + HelperQ80R8 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { - HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ60 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; + case GGML_TYPE_Q4_0: { + HelperQ40 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS - case GGML_TYPE_Q4_0: { - HelperQ40 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); - } break; case GGML_TYPE_Q4_1: { - HelperQ41 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ41 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperIQ4nl kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #endif default: break; @@ -16416,37 +18240,141 @@ inline bool flash_attn_is_supported(ggml_type type) { #endif #if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || - type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; + type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8 + || type == GGML_TYPE_Q4_0) return true; #endif return false; } + +template +inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, + int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; + if (nq1 >= 8) { + int n_step = nq1/8; + FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(8*n_step)) return; + } + if (nq1 >= 4) { + int n_step = nq1/4; + FlashAttn<576, 512, 4, step_k> fa(scale, softcap); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(4*n_step)) return; + } + if (nq1 >= 2) { + int n_step = nq1/2; + FlashAttn<576, 512, 2, step_k> fa(scale, softcap); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(2*n_step)) return; + } + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } -bool iqk_flash_attn_noalibi(int int_type_k, // type of k - int int_type_v, // type of v - int D, // head size - int nq1, // number of columns in q - int nk1, // number of rows in k - int stride_q, // distance between q columns in bytes - int stride_k, // distance between k rows in bytes - int stride_v, // distance between v rows in bytes - int stride_m, // distance between mask rows (in bytes - int stride_qkv, // distance between rows in mask (in bytes) - const float * q, // q matrix. - const void * k, // k matrix. Assumed to be fp16, nq x nk elements - const void * v, // v matrix. Assumed to be fp16, nq x nk elements - const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - float scale, // scale applied before softmax - float softcap, // if > 0, a "soft-cap" operation is applied before softmax - float * qkv) { // v*softmax(scale*(k*q)) +template +inline bool iqk_deepseek_helper(ggml_type type_k, + int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv, float * M, float * S) { + if (type_k == GGML_TYPE_Q8_0) { + HelperQ80<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } + if (type_k == GGML_TYPE_Q8_0_R8) { + HelperQ80R8<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } + if (type_k == GGML_TYPE_Q6_0) { + HelperQ60<576, step_k> kh((const char *)k, stride_k); + HelperQ60<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } + if (type_k == GGML_TYPE_Q8_KV) { + HelperQ8KV<576, step_k> kh((const char *)k, stride_k); + HelperQ8KV<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } + if (type_k == GGML_TYPE_F16) { + HelperF16<576, step_k> kh((const char *)k, stride_k); + HelperF16<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16) { + HelperBF16<576, step_k> kh((const char *)k, stride_k); + HelperBF16<512, step_k> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } else { + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } + return true; + } +#endif + return false; +} + +} + +#include "iqk_flash_impl.h" + +bool iqk_flash_attn_impl(int int_type_k, // type of k + int int_type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int nq1, // number of columns in q + int nk1, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + float * M, float * S) { + + if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 auto type_k = ggml_type(int_type_k); auto type_v = ggml_type(int_type_v); + + if (Dk == 576 && Dv == 512) { + GGML_ASSERT(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0)); + stride_q /= sizeof(float); // q stride as float + return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); + } + if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; - if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 - if (D != 64 && D != 96 && D != 128 && D != 256) return false; + if (Dk != Dv && Dk != 192 && Dv != 128) return false; + if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; + if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; @@ -16458,30 +18386,34 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_k == GGML_TYPE_BF16) { if (nk1%64 == 0) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 96: - iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 128: - iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } return true; } if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 96: - iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 128: - iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -16490,42 +18422,63 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k } #endif - if (nk1%64 == 0) { - switch (D) { + if (nk1%128 == 0) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; - // Disable until we fix accumulate_qkv for odd D/16 - //case 80: - // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 96: - iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; - // Disable until we fix accumulate_qkv for odd D/16 - //case 112: - // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 128: - iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } return true; } - switch (D) { + if (nk1%64 == 0) { + switch (Dk) { + case 64: + iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 80: + // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 112: + // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 256: + iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + default: + return false; + } + return true; + } + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -16535,32 +18488,28 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k #else // IQK_IMPLEMENT -bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { +extern "C" IQK_API bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { return false; } -bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, +extern "C" IQK_API bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, + long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/, + long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/, + int /*typeA*/, const void * /*A*/, long /*strideA*/, + int /*typeB*/, const void * /*B*/, long /*strideB*/, + float * /*C*/, long /*stride_C*/, int /*ith*/, int /*nth*/) { + return false; +} + +extern "C" IQK_API bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, const void *, int, int) { return false; } -bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k - [[maybe_unused]] int int_type_v, // type of v - [[maybe_unused]] int D, // head size - [[maybe_unused]] int nq, // number of columns in q - [[maybe_unused]] int nk, // number of rows in k - [[maybe_unused]] int stride_q, // distance between q columns in bytes - [[maybe_unused]] int stride_k, // distance between k rows in bytes - [[maybe_unused]] int stride_v, // distance between v rows in bytes - [[maybe_unused]] int stride_m, // distance between mask rows (in bytes - [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes) - [[maybe_unused]] const float * q, // q matrix. - [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements - [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements - [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - [[maybe_unused]] float scale, // scale applied before softmax - [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax - [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q)) +extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, + int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/, + int /*typeB*/, const void * /*B*/, long /*strideB*/, + float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) { return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 6e27c614..6f44af52 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -7,37 +7,59 @@ #pragma once #include #include +#include "iqk_config.h" #ifdef __cplusplus extern "C" { #endif -bool iqk_mul_mat(long Nx, long Ny, long ne00, +IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, +IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, + long ne02, long ne03, long ne12, long ne13, + long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth); + +IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); -bool iqk_flash_attn_noalibi(int type_k, // type of k +IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); + +typedef void (*barrier_t) (void *); + +IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, + int neq3, int neq2, long nbq3, long nbq2, + int nek3, int nek2, long nbk3, long nbk2, + int nev3, int nev2, long nbv3, long nbv2, + int ne2, int ne1, long nb1, + int type_k, // type of k int type_v, // type of v - int D, // head size + int Dk, // K head size + int Dv, // V head size int nq, // number of columns in q int nk, // number of rows in k int stride_q, // distance between q columns in bytes int stride_k, // distance between k rows in bytes int stride_v, // distance between v rows in bytes int stride_m, // distance between mask rows (in bytes - int stride_qkv, // distance between rows in mask (in bytes) - const float * q, // q matrix. + const void * q, // q matrix. const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax - float * qkv); // v*softmax(scale*(k*q)) + float * qkv, // v*softmax(scale*(k*q)) + void * work_buffer, barrier_t barrier, void * barrier_data, + int ith, int nth); #ifdef __cplusplus } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 5f5af45a..9d543506 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -12,6 +12,7 @@ #define GGML_COMMON_IMPL_C #include "ggml-common.h" #include "iqk_quantize.h" +#include "iqk_config.h" #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -45,15 +47,6 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); } constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); } #endif -#if defined __x86_64__ -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD -#endif -#endif - namespace { inline int nearest_int(float fval) { @@ -195,6 +188,34 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i } +void iqk_quantize_any(int from_type, int to_type, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, + uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3, + const void * x, void * y, void * work_buffer, + to_float_t to_float, from_float_t from_float, int ith, int nth) { + auto type_x = ggml_type(from_type); + GGML_ASSERT(ggml_type_size(type_x) == nb0); + auto type_y = ggml_type(to_type); + auto row_size_y = ggml_row_size(type_y, ne0); + int64_t nrows = ne1*ne2*ne3; + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = nrows_per_thread*ith; + if (first_row >= nrows) return; + int64_t last_row = std::min(first_row + nrows_per_thread, nrows); + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(ne1*ne2); + int64_t i2 = (row - i3*ne1*ne2)/ne1; + int64_t i1 = row - i3*ne1*ne2 - i2*ne1; + const char * cx = (const char *)x + i1*nb1 + i2*nb2 + i3*nb3; + // TODO: special case common types such as f16, q8_0 + // (although the performance gains may be too small to justify the added complexity) + to_float((const void *)cx, (float *)work_buffer, ne0); + auto cy = (char *)y + (i3*ne1*ne2 + i2*ne1 + i1)*row_size_y; + from_float((const float *)work_buffer, (void *)cy, ne0); + } +} + + size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { IQ1BNQuantizer iq1bn; auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row); @@ -779,13 +800,14 @@ void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) { #endif } -void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { +namespace { +template +void quantize_row_q8_1_x4_T(const float * x, Block * y, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; const int nb4 = 4*(nb/4); - block_q8_1 * y = (block_q8_1 *)vy; - block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy; + Block_x4 * y4 = (Block_x4 *)y; #if defined(__aarch64__) for (int i = 0; i < nb; i++) { int i4 = i/4, ir = i%4; @@ -832,10 +854,18 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { accv = vaddq_s32(accv, vi); } - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } else { + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_BF16(d * vaddvq_s32(accv)).bits; + } else { + y[i].s = GGML_FP32_TO_BF16(d * vaddvq_s32(accv)).bits; + } } } #else @@ -861,13 +891,25 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { const float max_scalar = _mm_cvtss_f32( max4 ); // Quantize these floats - const float d = max_scalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + float d = max_scalar / 127.f; + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } } else { - y[i].d = GGML_FP32_TO_FP16(d); + if (i < nb4) { + auto t = GGML_FP32_TO_BF16(d); + y4[i4].d[ir] = t.bits; + d = ggml_bf16_to_fp32(t); + } else { + auto t = GGML_FP32_TO_BF16(d); + y[i].d = t.bits; + d = ggml_bf16_to_fp32(t); + } } - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const float id = d > 0 ? 1/d : 0.f; const __m256 mul = _mm256_set1_ps( id ); // Apply the multiplier @@ -889,10 +931,19 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { __m256i i3 = _mm256_cvtps_epi32( v3 ); // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + int isum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * isum); + } else { + y[i].s = GGML_FP32_TO_FP16(d * isum); + } } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_BF16(d * isum).bits; + } else { + y[i].s = GGML_FP32_TO_BF16(d * isum).bits; + } } // Convert int32 to int16 @@ -915,6 +966,15 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { } #endif } +} + +void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { + quantize_row_q8_1_x4_T(x, (block_q8_1 *)vy, k); +} + +void quantize_row_q8_2_x4(const float * x, void * vy, int64_t k) { + quantize_row_q8_1_x4_T(x, (block_q8_2 *)vy, k); +} // // ============================================== iq2_K @@ -1497,12 +1557,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) { static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { - const int ntry = 5; + constexpr int ntry = 3; block_iq3_k * y = (block_iq3_k *)vy; float scales[QK_K/16]; float weight[16]; + uint8_t L[16]; const int8_t * shifted_values = iq3nl_values + 8; @@ -1562,7 +1623,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c } bool is_shifted = false; for (int itry = -ntry; itry <= ntry; ++itry) { - id = (itry + iq3nl_values[0])/max; + id = (2*itry + iq3nl_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1583,7 +1644,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; } - id = (itry + shifted_values[0])/max; + id = (2*itry + shifted_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1605,20 +1666,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; } } - if (d) { - const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; - float sumqx = 0, sumq2 = 0; - id = 1/d; + if (!d) { + scales[ib] = 0; continue; + } + + const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(block_values, al); + L[j] = l; + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + + float best_d = d; + for (int iter = 0; iter < 128; ++iter) { + float gmax = 0; + int best_j = -1, dir = 0; for (int j = 0; j < 16; ++j) { float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(block_values, al); - float q = block_values[l]; - sumqx += w*q*xb[j]; - sumq2 += w*q*q; + float g = d * w * (xb[j] - d*block_values[L[j]]); + if (g > 0 && L[j] < 7) { + if (g > gmax) { + gmax = g; best_j = j; dir = 1; + } + } + else if (g < 0 && L[j] > 0) { + if (-g > gmax) { + gmax = -g; best_j = j; dir = -1; + } + } } - if (sumq2 > 0) d = sumqx/sumq2; + if (best_j < 0) break; + + float w = weight[best_j]; + sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]); + sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]); + L[best_j] += dir; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + best_d = sumqx/sumq2; best = best_d*sumqx; + } + else if (iter > 8) break; + } + scales[ib] = d; if (is_shifted) extra |= (1 << ib); @@ -2969,6 +3065,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) { } #endif } +// TODO: merge this with the above template +void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + assert(k % 32 == 0); + auto dptr = (float *)vy; + auto q8 = (int8_t *)(dptr + 2); +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + __m256 maxAbs = _mm256_setzero_ps(); + for (int ib = 0; ib < k/8; ++ib) { + const __m256 v = _mm256_loadu_ps(x + 8*ib); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + if (!maxScalar) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = maxScalar / 127.f; + auto mul = _mm256_set1_ps(1/dptr[0]); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < k/32; i++) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0)); + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8)); + __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16)); + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24)); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = hsum_i32_8(isum); +#elif defined __ARM_NEON + int32x4_t ival[8]; + auto vmax = vdupq_n_f32(0.f); + for (int j = 0; j < k; j += 4) { + vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j))); + } + auto smax = vmaxvq_f32(vmax); + if (!smax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = smax/127; + auto vid = vdupq_n_f32(1/dptr[0]); + auto isum = vdupq_n_s32(0); + for (int ib = 0; ib < k/32; ++ib) { + auto xb = x + 32*ib; + for (int k = 0; k < 8; ++k) { + auto val = vld1q_f32(xb + 4*k); + ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid)); + isum = vaddq_s32(isum, ival[k]); + } + for (int k = 0; k < 4; ++k) { + auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1])); + vst1_s8(q8, vmovn_s16(i16)); + q8 += 8; + } + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = vaddvq_s32(isum); +#else + float amax = 0; + for (int j = 0; j < k; ++j) { + float ax = std::abs(x[j]); + amax = std::max(amax, ax); + } + if (!amax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = amax/127; + float id = 1/dptr[0]; + int isum = 0; + for (int i = 0; i < k; i++) { + q8[i] = nearest_int(id*x[i]); + isum += q8[i]; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = isum; +#endif +} } void quantize_row_q8_K128(const float * x, void * vy, int64_t k) { @@ -3888,7 +4081,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8 #ifdef HAVE_FANCY_SIMD static void modify_q8_0_r8(int64_t k, char * cy) { - auto y = (block_iq4_nl_r8 *)cy; + auto y = (block_q8_0_r8 *)cy; int nb = k/(32*8); for (int ib = 0; ib < nb; ++ib) { for (int l = 0; l < 4; ++l) { @@ -5414,6 +5607,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } +// +// ========================================= q8_KV_r8 +// + +void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + const int8_t * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = 0; row < nrows; row += 8) { + auto dy = (float *)cy; + auto qy = (int8_t *)(dy + 8); + for (int k = 0; k < 8; ++k) { + auto dx = (const float *)(cx + k*row_size_x); + dy[k] = dx[0]; + x8[k] = (const int8_t *)(dx + 2); + } + for (int ib = 0; ib < n_per_row/16; ++ib) { +#ifdef __AVX2__ +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + if (online) { + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + } +#endif + _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3); +#elif defined __ARM_NEON + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(qy + 0 + 128*ib, m0); + vst1q_s8_x2(qy + 32 + 128*ib, m1); + vst1q_s8_x2(qy + 64 + 128*ib, m2); + vst1q_s8_x2(qy + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + + } + cx += 8*row_size_x; + cy += online ? 8*row_size_x : 8*row_size_y; + //So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row + } +} +#ifdef HAVE_FANCY_SIMD +static void modify_q8_KV_r8(int64_t k, char * cy) { + int8_t * q8 = (int8_t *)(cy + 8*sizeof(float)); + for (int j = 0; j < k; ++j) q8[j] += 127; +} +#endif + +size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + char * qcur = (char *)dst; + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + std::vector qtmp(8*row_size_0); + for (int row = 0; row < nrows; row += 8) { + quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix); + repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false); + qcur += 8*row_size_1; + src += 8*n_per_row; + } + return nrows*row_size_1; +} + +void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) { + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; + auto dptr = (const float *)vx; + auto q8 = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n_per_row/16; ++ib) { + for (int k = 0; k < 8; ++k) { + for (int l = 0; l < 4; ++l) { + for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i]; + } + } + } +} + +void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + // // ========================================= bf16_r4 // @@ -6237,8 +6574,8 @@ size_t quantize_iq1_s_r4(const float * src, void * dst, int64_t nrows, int64_t n auto xb = src + k*n_per_row + kBlockSize*ibl; float sumx2 = 0; for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; - if (!sumx2) { - printf("Found block with all zeros\n"); + if (sumx2 < 1e-14f) { + //printf("Found block with all zeros\n"); // all zero int ind = 1029; // this is the grid entry with all zeros scales[4*ibl+k] = 0; @@ -6366,20 +6703,49 @@ size_t quantize_iq1_m_r4(const float * src, void * dst, int64_t nrows, int64_t n for (int ibl = 0; ibl < nblock; ++ibl) { for (int k = 0; k < 4; ++k) { auto xb = src + k*n_per_row + kBlockSize*ibl; - float sumx2 = 0; - for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; - if (!sumx2) { + float sumx2l = 0, sumx2h = 0; + for (int j = 0; j < kBlockSize/2; ++j) sumx2l += xb[j]*xb[j]; + for (int j = kBlockSize/2; j < kBlockSize; ++j) sumx2h += xb[j]*xb[j]; + float sumx2 = sumx2l + sumx2h; + if (sumx2 < 1e-14f) { scales[8*ibl+2*k+0] = scales[8*ibl+2*k+1] = 0; + int ind = 1029; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[4*i + k] = ind & 255; + } + for (int i = 0; i < 2; ++i) { + y[ibl].qh[4*i+k] = (ind >> 8) | ((ind >> 8) << 4); + } continue; } float sigma2 = 1.5f*sumx2/kBlockSize; if (imatrix) { for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); + float sumwx = 0; + for (int j = 0; j < kBlockSize/2; ++j) sumwx += weight[j]*std::abs(xb[j]); + if (sumwx < 1e-14f) { + for (int j = 0; j < kBlockSize/2; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } + sumwx = 0; + for (int j = kBlockSize/2; j < kBlockSize; ++j) sumwx += weight[j]*std::abs(xb[j]); + if (sumwx < 1e-14) { + for (int j = kBlockSize/2; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } } else { for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); } - iq1m_process_1block(xb+ 0, weight+ 0, L, scales.data() + 8*ibl + 2*k+0, index+0, &shift1, pairs); - iq1m_process_1block(xb+16, weight+16, L, scales.data() + 8*ibl + 2*k+1, index+2, &shift2, pairs); + if (sumx2l > 1e-14f) { + iq1m_process_1block(xb+ 0, weight+ 0, L, scales.data() + 8*ibl + 2*k+0, index+0, &shift1, pairs); + } else { + scales[8*ibl+2*k+0] = 0; + index[0] = index[1] = 1029; + } + if (sumx2h > 1e-14f) { + iq1m_process_1block(xb+16, weight+16, L, scales.data() + 8*ibl + 2*k+1, index+2, &shift2, pairs); + } else { + scales[8*ibl+2*k+1] = 0; + index[2] = index[3] = 1029; + } max[k] = std::max(max[k], std::max(scales[8*ibl+2*k+0], scales[8*ibl+2*k+1])); for (int i = 0; i < 4; ++i) { y[ibl].qs[4*i + k] = index[i] & 255; @@ -6452,6 +6818,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +void quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_KV(x, vy, k); +} + +void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) { + quantize_row_q8_KV(x, y, k); +} + +size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + (void)imatrix; + auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto q = (char *)dst; + for (int row = 0; row < nrows; ++row) { + quantize_row_q8_KV(src, q, n_per_row); + src += n_per_row; + q += row_size; + } + return row_size*nrows; +} + +void dequantize_row_q8_KV(const void * x, float * y, int64_t k) { + auto dptr = (const float *)x; + float d = dptr[0]; + auto q8 = (const int8_t *)(dptr + 2); + for (int j = 0; j < k; ++j) y[j] = d * q8[j]; +} + +void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + + //================================================ namespace { @@ -6466,22 +6873,42 @@ struct Modify { modify_func_t mod_func; int nrows; }; -} - -bool iqk_modify_tensor(struct ggml_tensor * tensor) { +const Modify * get_modify_info(ggml_type type) { static const std::unordered_map k_mod_map = { #ifdef __ARM_NEON { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} }, #endif #ifdef HAVE_FANCY_SIMD - { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, - { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, + { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} }, #endif }; - auto it = k_mod_map.find(tensor->type); - if (it == k_mod_map.end()) return false; + auto it = k_mod_map.find(type); + return it != k_mod_map.end() ? &it->second : nullptr; +} +bool is_forbidden_tensor(const std::string& name) { + static const std::string kTokenEmbd{"token_embd.weight"}; + if (name == kTokenEmbd) return true; + //if (auto pos = name.find("attn_kv_b.weight"); pos != std::string::npos) return true; + return false; +} +} - auto& m = it->second; +bool iqk_should_modify_tensor([[maybe_unused]] const struct ggml_tensor * tensor) { + return false; + //if (is_forbidden_tensor(tensor->name)) return false; + //auto mptr = get_modify_info(tensor->type); + //return mptr ? true : false; +} + +bool iqk_modify_tensor(struct ggml_tensor * tensor) { + return false; + auto mptr = get_modify_info(tensor->type); + if (!mptr) return false; + if (is_forbidden_tensor(std::string{tensor->name})) return false; + + auto& m = *mptr; int nrows = ggml_nrows(tensor); int nchunks = nrows/m.nrows; int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); @@ -6504,12 +6931,8 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) { return true; } -void iqk_repack_tensor(struct ggml_tensor * tensor) { - constexpr int kChunk = 8; - if (!tensor) return; - if (!ggml_is_contiguous(tensor)) return; - if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return; - if (tensor->ne[1] % 4 || tensor->ne[2]*tensor->ne[3] > 1) return; +namespace { +const Repack * get_repack_info(ggml_type type) { static const std::unordered_map k_map = { { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} }, { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, @@ -6534,20 +6957,41 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, + { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16}}, { GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16} }, #endif }; + auto it = k_map.find(type); + return it != k_map.end() ? &it->second : nullptr; +} +} - auto it = k_map.find(tensor->type); - if (it == k_map.end()) return; - if (tensor->ne[1] % it->second.num_rows) return; +int iqk_repacked_type(const struct ggml_tensor * tensor) { + if (!ggml_is_contiguous(tensor)) return (int)tensor->type; + if (is_forbidden_tensor(tensor->name)) return (int)tensor->type; + auto rptr = get_repack_info(tensor->type); + return rptr && tensor->ne[1] % rptr->num_rows == 0 ? (int)rptr->new_type : (int)tensor->type; +} - auto& r = it->second; +void iqk_repack_tensor(struct ggml_tensor * tensor) { + constexpr int kChunk = 8; + if (!tensor) return; + if (!ggml_is_contiguous(tensor)) return; + if (is_forbidden_tensor(tensor->name)) return; + if (tensor->ne[1] % 4) return; + + auto rptr = get_repack_info(tensor->type); + if (!rptr) return; + if (tensor->ne[1] % rptr->num_rows) return; + + auto& r = *rptr; + + auto nrows = ggml_nrows(tensor); int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); - int num_chunks = (tensor->ne[1] + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); + int num_chunks = (nrows + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); int nthread = std::min(num_chunks, max_thread); //printf("%s(%s): %s -> %s. %d rows, %d chunks, %d threads\n", __func__, tensor->name, ggml_type_name(tensor->type), ggml_type_name(r.new_type), @@ -6555,7 +6999,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { std::atomic counter(0);; auto compute = [&counter, &r, tensor, num_chunks, chunkSize = kChunk] () { - int nrows = tensor->ne[1]; + int nrows = ggml_nrows(tensor); int n_per_row = tensor->ne[0]; auto row_size = ggml_row_size(tensor->type, n_per_row); std::vector qtmp(r.num_rows*row_size); @@ -6567,7 +7011,8 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { int last_row = std::min(first_row + chunkSize*r.num_rows, nrows); for (int row = first_row; row < last_row; row += r.num_rows) { std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size); - r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, true); + //r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, true); + r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, false); } } }; diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 8299ec74..93034ac0 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include @@ -235,6 +241,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -244,6 +262,7 @@ void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_2_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); @@ -251,9 +270,20 @@ void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT d void iqk_repack_tensor(struct ggml_tensor * tensor); bool iqk_modify_tensor(struct ggml_tensor * tensor); +int iqk_repacked_type(const struct ggml_tensor * tensor); // int instead of ggml_type so we don't need to include ggml.h +bool iqk_should_modify_tensor(const struct ggml_tensor * tensor); + // So we can re-pack Microsoft's BitNet I2_S quants void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +typedef void (*to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +typedef void (*from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void iqk_quantize_any(int from_type, int to_type, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, + uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3, + const void * GGML_RESTRICT x, void * GGML_RESTRICT y, void * work_buffer, + to_float_t to_float, from_float_t from_float, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 90d5efec..6819979f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -181,6 +181,7 @@ class GGUFType: class MODEL_ARCH(IntEnum): LLAMA = auto() + DECI = auto() FALCON = auto() BAICHUAN = auto() GROK = auto() @@ -198,6 +199,8 @@ class MODEL_ARCH(IntEnum): QWEN = auto() QWEN2 = auto() QWEN2MOE = auto() + QWEN3 = auto() + QWEN3MOE = auto() PHI2 = auto() PHI3 = auto() PLAMO = auto() @@ -207,6 +210,7 @@ class MODEL_ARCH(IntEnum): MINICPM = auto() GEMMA = auto() GEMMA2 = auto() + GEMMA3 = auto() STARCODER2 = auto() MAMBA = auto() XVERSE = auto() @@ -218,6 +222,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() BITNET = auto() + BITNET_25 = auto() T5 = auto() T5ENCODER = auto() JAIS = auto() @@ -274,6 +279,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -310,6 +317,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.DECI: "deci", MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.BAICHUAN: "baichuan", MODEL_ARCH.GROK: "grok", @@ -327,6 +335,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PLAMO: "plamo", @@ -336,6 +346,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MINICPM: "minicpm", MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", @@ -347,6 +358,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.BITNET_25: "bitnet-25", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", @@ -403,6 +415,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -458,6 +472,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.GROK: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -693,6 +727,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.QWEN3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN3MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -738,6 +806,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, @@ -967,6 +1037,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, @@ -1011,6 +1083,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, ], + MODEL_ARCH.BITNET_25: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], MODEL_ARCH.T5: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, @@ -1079,6 +1173,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.BAICHUAN: [ MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, @@ -1139,47 +1237,86 @@ class PoolingType(IntEnum): class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 - Q4_0 = 2 - Q4_1 = 3 - Q5_0 = 6 - Q5_1 = 7 - Q8_0 = 8 - Q8_1 = 9 - Q2_K = 10 - Q3_K = 11 - Q4_K = 12 - Q5_K = 13 - Q6_K = 14 - Q8_K = 15 - IQ2_XXS = 16 - IQ2_XS = 17 - IQ3_XXS = 18 - IQ1_S = 19 - IQ4_NL = 20 - IQ3_S = 21 - IQ2_S = 22 - IQ4_XS = 23 - I8 = 24 - I16 = 25 - I32 = 26 - I64 = 27 - F64 = 28 - IQ1_M = 29 - BF16 = 30 - Q4_0_4_4 = 31 - Q4_0_4_8 = 32 - Q4_0_8_8 = 33 - IQ1_BN = 34, - IQ2_BN = 35, - Q8_K64 = 36, - IQ2_K = 37, - IQ3_K = 38, - IQ4_K = 39, - IQ5_K = 40, - IQ6_K = 41, - IQ2_TN = 42, + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = 22 + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + BF16 = 30 + Q4_0_4_4 = 31 + Q4_0_4_8 = 32 + Q4_0_8_8 = 33 + I2_S = 36 + Q8_0_X4 = 97 + Q8_1_X4 = 98 + Q8_2_X4 = 99 + Q6_0 = 133 + IQ1_BN = 134 + IQ2_BN = 135 + Q8_K64 = 136 + IQ2_K = 137 + IQ3_K = 138 + IQ4_K = 139 + IQ5_K = 140 + IQ6_K = 141 + IQ4_KS = 144 + IQ2_KS = 145 + IQ4_KSS = 146 + Q8_K16 = 147 + Q8_K32 = 148 + Q8_KR8 = 149 + Q8_K128 = 150 + Q8_KV = 151 + Q4_0_R8 = 202 + Q5_0_R4 = 206 + Q8_0_R8 = 208 + Q2_K_R4 = 210 + Q3_K_R4 = 211 + Q4_K_R4 = 212 + Q5_K_R4 = 213 + Q6_K_R4 = 214 + IQ2_XXS_R4= 216 + IQ2_XS_R4 = 217 + IQ3_XXS_R4= 218 + IQ1_S_R4 = 219 + IQ4_NL_R4 = 220 + IQ3_S_R4 = 221 + IQ2_S_R4 = 222 + IQ4_XS_R8 = 223 + IQ1_M_R4 = 229 + BF16_R16 = 230 + Q6_0_R4 = 233 + IQ2_BN_R4 = 335 + IQ2_K_R4 = 337 + IQ3_K_R4 = 338 + IQ4_K_R4 = 339 + IQ5_K_R4 = 340 + IQ4_KS_R4 = 344 + Q8_KV_R8 = 398 + Q8_K_R8 = 399 class ExpertGatingFuncType(IntEnum): @@ -1193,50 +1330,71 @@ class ExpertGatingFuncType(IntEnum): # from llama_ftype in llama.h # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. class LlamaFileType(IntEnum): - ALL_F32 = 0 - MOSTLY_F16 = 1 # except 1d tensors - MOSTLY_Q4_0 = 2 # except 1d tensors - MOSTLY_Q4_1 = 3 # except 1d tensors - # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 - # MOSTLY_Q4_2 = 5 # support has been removed - # MOSTLY_Q4_3 = 6 # support has been removed - MOSTLY_Q8_0 = 7 # except 1d tensors - MOSTLY_Q5_0 = 8 # except 1d tensors - MOSTLY_Q5_1 = 9 # except 1d tensors - MOSTLY_Q2_K = 10 # except 1d tensors - MOSTLY_Q3_K_S = 11 # except 1d tensors - MOSTLY_Q3_K_M = 12 # except 1d tensors - MOSTLY_Q3_K_L = 13 # except 1d tensors - MOSTLY_Q4_K_S = 14 # except 1d tensors - MOSTLY_Q4_K_M = 15 # except 1d tensors - MOSTLY_Q5_K_S = 16 # except 1d tensors - MOSTLY_Q5_K_M = 17 # except 1d tensors - MOSTLY_Q6_K = 18 # except 1d tensors - MOSTLY_IQ2_XXS = 19 # except 1d tensors - MOSTLY_IQ2_XS = 20 # except 1d tensors - MOSTLY_Q2_K_S = 21 # except 1d tensors - MOSTLY_IQ3_XS = 22 # except 1d tensors - MOSTLY_IQ3_XXS = 23 # except 1d tensors - MOSTLY_IQ1_S = 24 # except 1d tensors - MOSTLY_IQ4_NL = 25 # except 1d tensors - MOSTLY_IQ3_S = 26 # except 1d tensors - MOSTLY_IQ3_M = 27 # except 1d tensors - MOSTLY_IQ2_S = 28 # except 1d tensors - MOSTLY_IQ2_M = 29 # except 1d tensors - MOSTLY_IQ4_XS = 30 # except 1d tensors - MOSTLY_IQ1_M = 31 # except 1d tensors - MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q4_0_4_4 = 33 # except 1d tensors - MOSTLY_Q4_0_4_8 = 34 # except 1d tensors - MOSTLY_Q4_0_8_8 = 35 # except 1d tensors - MOSTLY_IQ1_BN = 36, # except 1d tensors - MOSTLY_IQ2_BN = 37, # except 1d tensors - MOSTLY_IQ2_K = 38, # except 1d tensors - MOSTLY_IQ3_K = 39, # except 1d tensors - MOSTLY_IQ4_K = 40, # except 1d tensors - MOSTLY_IQ5_K = 41, # except 1d tensors - MOSTLY_IQ6_K = 42, # except 1d tensors - MOSTLY_IQ2_TN = 43, # except 1d tensors + ALL_F32 = 0 + MOSTLY_F16 = 1 #except 1d tensors + MOSTLY_Q4_0 = 2 #except 1d tensors + MOSTLY_Q4_1 = 3 #except 1d tensors + MOSTLY_Q4_1_SOME_F16 = 4 #tok_embeddings.weight and output.weight are F16 + MOSTLY_Q8_0 = 7 #except 1d tensors + MOSTLY_Q5_0 = 8 #except 1d tensors + MOSTLY_Q5_1 = 9 #except 1d tensors + MOSTLY_Q2_K = 10 #except 1d tensors + MOSTLY_Q3_K = 11 #except 1d tensors + MOSTLY_Q4_K = 12 #except 1d tensors + MOSTLY_Q5_K = 13 #except 1d tensors + MOSTLY_Q6_K = 14 #except 1d tensors + MOSTLY_IQ2_XXS = 15 #except 1d tensors + MOSTLY_IQ2_XS = 16 #except 1d tensors + MOSTLY_IQ3_XXS = 17 #except 1d tensors + MOSTLY_IQ1_S = 18 #except 1d tensors + MOSTLY_IQ4_NL = 19 #except 1d tensors + MOSTLY_IQ3_S = 20 #except 1d tensors + MOSTLY_IQ2_S = 21 #except 1d tensors + MOSTLY_IQ4_XS = 22 #except 1d tensors + MOSTLY_IQ1_M = 23 #except 1d tensors + MOSTLY_BF16 = 24 #except 1d tensors + MOSTLY_Q4_0_4_4 = 25 #except 1d tensors + MOSTLY_Q4_0_4_8 = 26 #except 1d tensors + MOSTLY_Q4_0_8_8 = 27 #except 1d tensors + MOSTLY_Q6_0 = 127 #except 1d tensors + MOSTLY_IQ1_BN = 128 #except 1d tensors + MOSTLY_IQ2_BN = 129 #except 1d tensors + MOSTLY_IQ2_K = 130 #except 1d tensors + MOSTLY_IQ3_K = 131 #except 1d tensors + MOSTLY_IQ4_K = 132 #except 1d tensors + MOSTLY_IQ5_K = 133 #except 1d tensors + MOSTLY_IQ6_K = 134 #except 1d tensors + MOSTLY_IQ4_KS = 137 #except 1d tensors + MOSTLY_IQ2_KS = 138 #except 1d tensors + MOSTLY_IQ4_KSS = 139 #except 1d tensors + MOSTLY_Q8_KV = 140 #except 1d tensors + MOSTLY_Q4_0_R8 = 202 #except 1d tensors + MOSTLY_Q8_0_R8 = 207 #except 1d tensors + MOSTLY_Q5_0_R4 = 208 #except 1d tensors + MOSTLY_Q2_K_R4 = 210 #except 1d tensors + MOSTLY_Q3_K_R4 = 211 #except 1d tensors + MOSTLY_Q4_K_R4 = 212 #except 1d tensors + MOSTLY_Q5_K_R4 = 213 #except 1d tensors + MOSTLY_Q6_K_R4 = 214 #except 1d tensors + MOSTLY_IQ2_XXS_R4 = 215 #except 1d tensors + MOSTLY_IQ2_XS_R4 = 216 #except 1d tensors + MOSTLY_IQ3_XXS_R4 = 217 #except 1d tensors + MOSTLY_IQ1_S_R4 = 218 #except 1d tensors + MOSTLY_IQ4_NL_R4 = 219 #except 1d tensors + MOSTLY_IQ3_S_R4 = 220 #except 1d tensors + MOSTLY_IQ2_S_R4 = 221 #except 1d tensors + MOSTLY_IQ4_XS_R8 = 222 #except 1d tensors + MOSTLY_IQ1_M_R4 = 223 #except 1d tensors + MOSTLY_BF16_R16 = 224 #except 1d tensors + MOSTLY_Q6_0_R4 = 227 #except 1d tensors + MOSTLY_IQ2_BN_R4 = 329 #except 1d tensors + MOSTLY_IQ2_K_R4 = 330 #except 1d tensors + MOSTLY_IQ3_K_R4 = 331 #except 1d tensors + MOSTLY_IQ4_K_R4 = 332 #except 1d tensors + MOSTLY_IQ5_K_R4 = 333 #except 1d tensors + MOSTLY_IQ4_KS_R4 = 337 #except 1d tensors + MOSTLY_Q8_KV_R8 = 398 #except 1d tensors + MOSTLY_Q8_K_R8 = 399 #except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1281,39 +1439,89 @@ class GGUFValueType(IntEnum): # Items here are (block size, type size) QK_K = 256 + +#Values generated programatically GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { - GGMLQuantizationType.F32: (1, 4), - GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q4_0: (32, 2 + 16), - GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), - GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), - GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), - GGMLQuantizationType.Q8_0: (32, 2 + 32), - GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), - GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), - GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), - GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), - GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), - GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), - GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), - GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), - GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), - GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), - GGMLQuantizationType.IQ4_NL: (32, 2 + 16), - GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), - GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), - GGMLQuantizationType.I8: (1, 1), - GGMLQuantizationType.I16: (1, 2), - GGMLQuantizationType.I32: (1, 4), - GGMLQuantizationType.I64: (1, 8), - GGMLQuantizationType.F64: (1, 8), - GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), - GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), - GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), - GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), + GGMLQuantizationType.F32 : ( 1, 4), + GGMLQuantizationType.F16 : ( 1, 2), + GGMLQuantizationType.Q4_0 : ( 32, 18), + GGMLQuantizationType.Q4_1 : ( 32, 20), + GGMLQuantizationType.Q5_0 : ( 32, 22), + GGMLQuantizationType.Q5_1 : ( 32, 24), + GGMLQuantizationType.Q8_0 : ( 32, 34), + GGMLQuantizationType.Q8_1 : ( 32, 36), + GGMLQuantizationType.Q2_K : ( 256, 84), + GGMLQuantizationType.Q3_K : ( 256, 110), + GGMLQuantizationType.Q4_K : ( 256, 144), + GGMLQuantizationType.Q5_K : ( 256, 176), + GGMLQuantizationType.Q6_K : ( 256, 210), + GGMLQuantizationType.Q8_K : ( 256, 292), + GGMLQuantizationType.IQ2_XXS : ( 256, 66), + GGMLQuantizationType.IQ2_XS : ( 256, 74), + GGMLQuantizationType.IQ3_XXS : ( 256, 98), + GGMLQuantizationType.IQ1_S : ( 256, 50), + GGMLQuantizationType.IQ4_NL : ( 32, 18), + GGMLQuantizationType.IQ3_S : ( 256, 110), + GGMLQuantizationType.IQ2_S : ( 256, 82), + GGMLQuantizationType.IQ4_XS : ( 256, 136), + GGMLQuantizationType.I8 : ( 1, 1), + GGMLQuantizationType.I16 : ( 1, 2), + GGMLQuantizationType.I32 : ( 1, 4), + GGMLQuantizationType.I64 : ( 1, 8), + GGMLQuantizationType.F64 : ( 1, 8), + GGMLQuantizationType.IQ1_M : ( 256, 56), + GGMLQuantizationType.BF16 : ( 1, 2), + GGMLQuantizationType.Q4_0_4_4 : ( 32, 18), + GGMLQuantizationType.Q4_0_4_8 : ( 32, 18), + GGMLQuantizationType.Q4_0_8_8 : ( 32, 18), + GGMLQuantizationType.I2_S : ( 1, 1), + GGMLQuantizationType.Q8_0_X4 : ( 32, 34), + GGMLQuantizationType.Q8_1_X4 : ( 32, 36), + GGMLQuantizationType.Q8_2_X4 : ( 32, 36), + GGMLQuantizationType.Q6_0 : ( 32, 26), + GGMLQuantizationType.IQ1_BN : ( 64, 13), + GGMLQuantizationType.IQ2_BN : ( 64, 16), + GGMLQuantizationType.Q8_K64 : ( 64, 68), + GGMLQuantizationType.IQ2_K : ( 256, 76), + GGMLQuantizationType.IQ3_K : ( 256, 110), + GGMLQuantizationType.IQ4_K : ( 256, 144), + GGMLQuantizationType.IQ5_K : ( 256, 176), + GGMLQuantizationType.IQ6_K : ( 256, 212), + GGMLQuantizationType.IQ4_KS : ( 256, 136), + GGMLQuantizationType.IQ2_KS : ( 256, 70), + GGMLQuantizationType.IQ4_KSS : ( 256, 128), + GGMLQuantizationType.Q8_K16 : ( 64, 64), + GGMLQuantizationType.Q8_K32 : ( 256, 292), + GGMLQuantizationType.Q8_KR8 : ( 256, 292), + GGMLQuantizationType.Q8_K128 : ( 128, 140), + GGMLQuantizationType.Q8_KV : ( 32, 32), + GGMLQuantizationType.Q4_0_R8 : ( 32, 18), + GGMLQuantizationType.Q5_0_R4 : ( 32, 22), + GGMLQuantizationType.Q8_0_R8 : ( 32, 34), + GGMLQuantizationType.Q2_K_R4 : ( 256, 84), + GGMLQuantizationType.Q3_K_R4 : ( 256, 110), + GGMLQuantizationType.Q4_K_R4 : ( 256, 144), + GGMLQuantizationType.Q5_K_R4 : ( 256, 176), + GGMLQuantizationType.Q6_K_R4 : ( 256, 210), + GGMLQuantizationType.IQ2_XXS_R4 : ( 256, 66), + GGMLQuantizationType.IQ2_XS_R4 : ( 256, 74), + GGMLQuantizationType.IQ3_XXS_R4 : ( 256, 98), + GGMLQuantizationType.IQ1_S_R4 : ( 32, 6), + GGMLQuantizationType.IQ4_NL_R4 : ( 32, 18), + GGMLQuantizationType.IQ3_S_R4 : ( 256, 110), + GGMLQuantizationType.IQ2_S_R4 : ( 256, 82), + GGMLQuantizationType.IQ4_XS_R8 : ( 256, 136), + GGMLQuantizationType.IQ1_M_R4 : ( 32, 7), + GGMLQuantizationType.BF16_R16 : ( 1, 2), + GGMLQuantizationType.Q6_0_R4 : ( 32, 26), + GGMLQuantizationType.IQ2_BN_R4 : ( 64, 16), + GGMLQuantizationType.IQ2_K_R4 : ( 256, 76), + GGMLQuantizationType.IQ3_K_R4 : ( 256, 110), + GGMLQuantizationType.IQ4_K_R4 : ( 256, 144), + GGMLQuantizationType.IQ5_K_R4 : ( 256, 176), + GGMLQuantizationType.IQ4_KS_R4 : ( 256, 136), + GGMLQuantizationType.Q8_KV_R8 : ( 32, 32), + GGMLQuantizationType.Q8_K_R8 : ( 256, 258), } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a70b69c5..9688b02c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -82,6 +82,9 @@ class TensorNameMap: "rope.freqs", # llama-pth "rotary_pos_emb.inv_freq", # chatglm ), + + MODEL_TENSOR.ROPE_FACTORS_LONG: (), + MODEL_TENSOR.ROPE_FACTORS_SHORT: (), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { @@ -131,6 +134,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm + "layers.{bid}.attention.wqkv", ), # Attention query @@ -175,7 +179,8 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j @@ -446,6 +451,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), @@ -456,10 +469,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_SUB_NORM: ( "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + "layers.{bid}.attention.attn_sub_norm", # bitnet + "model.layers.{bid}.self_attn.attn_sub_norm", ), MODEL_TENSOR.FFN_SUB_NORM: ( "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + "layers.{bid}.feed_forward.ffn_sub_norm", # bitnet + "model.layers.{bid}.mlp.ffn_sub_norm", ), MODEL_TENSOR.DEC_ATTN_NORM: ( diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index dc574991..cca09798 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -122,8 +122,30 @@ class SpecialVocab: tokenizer = json.load(f) if self.load_merges: merges = tokenizer.get('model', {}).get('merges') - if isinstance(merges, list) and merges and isinstance(merges[0], str): - self.merges = merges + if isinstance(merges, list) and merges: + if isinstance(merges[0], str): + self.merges = merges + elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + # New format since transformers 4.45 to support spaces in merges + # ref: https://github.com/ggml-org/llama.cpp/issues/9692 + # TODO: internally store as the new format instead of converting to old + if any(' ' in s for pair in merges for s in pair): + logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + self.merges = [ + ' '.join( + [ + # ensure the spaces are properly encoded + ''.join( + chr(ord(c) + 256) if c == ' ' else c + for c in part + ) + for part in pair + ] + ) + for pair in merges + ] + else: + raise ValueError("Unknown tokenizer merges format") added_tokens = tokenizer.get('added_tokens', {}) else: added_tokens = {} @@ -132,7 +154,12 @@ class SpecialVocab: return True with open(tokenizer_config_file, encoding = 'utf-8') as f: tokenizer_config = json.load(f) - chat_template = tokenizer_config.get('chat_template') + chat_template_alt = None + chat_template_file = path / 'chat_template.json' + if chat_template_file.is_file(): + with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_alt = json.load(f).get('chat_template') + chat_template = tokenizer_config.get('chat_template', chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): self.chat_template = chat_template else: diff --git a/include/llama.h b/include/llama.h index 21479525..ed24f862 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #ifndef LLAMA_H #define LLAMA_H @@ -93,7 +100,12 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28 + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, }; // note: these values should be synchronized with ggml_rope @@ -180,9 +192,10 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ2_KT = 149, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ3_KT = 150, // except 1d tensors - LLAMA_FTYPE_MOSTLY_IQ4_KT = 151, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_KT = 150, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_KT = 151, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_KT = 152, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -209,6 +222,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file @@ -219,7 +233,8 @@ extern "C" { LLAMA_ROPE_SCALING_TYPE_NONE = 0, LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, LLAMA_ROPE_SCALING_TYPE_YARN = 2, - LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN, + LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3, + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE, }; enum llama_pooling_type { @@ -306,6 +321,11 @@ extern "C" { }; }; + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + struct llama_model_params { int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -333,12 +353,15 @@ extern "C" { // override key-value pairs of the model meta data const struct llama_model_kv_override * kv_overrides; + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool repack_tensors;// repack if available + bool use_thp; // uase transparent huge pages (linux only) }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -377,6 +400,11 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + int mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] + bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] + int min_experts; + float thresh_experts; // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -405,8 +433,11 @@ extern "C" { bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards bool ignore_imatrix_rules; // If set to true, the built-in rules for refusing to quantize into certain quants without imatrix are ignored + bool only_repack; // Only repack tensors void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides + void * custom_quants; // pointer to vector containing custom quantization rules + void * repack_pattern; // pointer to a vector containing regexes to be used for matching tensor names. Can be null } llama_model_quantize_params; // grammar types diff --git a/src/llama-impl.h b/src/llama-impl.h index 95277409..a9cbe0df 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #define LLAMA_API_INTERNAL diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 4bd5aa81..09399417 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -439,6 +439,27 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GPT4O: + regex_exprs = { + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: + regex_exprs = { + "\\p{N}+", + "(?=(\\d{3})+(?!\\d))", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE: + regex_exprs = { + // original regex from tokenizer.json + // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?) + "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { diff --git a/src/llama.cpp b/src/llama.cpp index 61d56e8c..6cdc00c8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "llama-impl.h" #include "llama-vocab.h" #include "llama-grammar.h" @@ -12,6 +19,8 @@ // TODO: fix this include #include "iqk/iqk_quantize.h" +#define IK_PRINT_TIMING 0 + #ifdef GGML_USE_RPC # include "ggml-rpc.h" #endif @@ -99,6 +108,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -174,6 +184,8 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, + LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, @@ -191,6 +203,8 @@ enum llm_arch { LLM_ARCH_QWEN, LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PLAMO, @@ -200,6 +214,7 @@ enum llm_arch { LLM_ARCH_MINICPM, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, @@ -210,17 +225,23 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, LLM_ARCH_BITNET, + LLM_ARCH_BITNET_25, + LLM_ARCH_BITNET_B158, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, - LLM_ARCH_GRANITE = 46, + LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_COHERE2, LLM_ARCH_UNKNOWN, }; static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, @@ -238,6 +259,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PLAMO, "plamo" }, @@ -247,6 +270,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -257,12 +281,16 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_BITNET_25, "bitnet-25" }, + { LLM_ARCH_BITNET_B158, "bitnet-b1.58" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_COHERE2, "cohere2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -307,6 +335,8 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -410,6 +440,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -539,6 +571,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, @@ -600,6 +634,61 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_BAICHUAN, { @@ -885,6 +974,45 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PHI2, { @@ -1051,6 +1179,26 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1203,6 +1351,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1234,6 +1384,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -1252,6 +1421,62 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, }, }, + { + LLM_ARCH_BITNET_25, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_BITNET_B158, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, { LLM_ARCH_T5, { @@ -1355,7 +1580,21 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, - + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1396,6 +1635,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_LLAMA4, + LLM_CHAT_TEMPLATE_BITNET, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -1431,6 +1672,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, + { "bitnet", LLM_CHAT_TEMPLATE_BITNET }, }; @@ -1798,6 +2041,7 @@ using llama_files = std::vector>; struct llama_mmap { void * addr; size_t size; + size_t mapped_page_size = 0; llama_mmap(const llama_mmap &) = delete; @@ -1807,7 +2051,7 @@ struct llama_mmap { // list of mapped fragments (first_offset, last_offset) std::vector> mapped_fragments; - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false, [[maybe_unused]] bool use_thp = false) { size = file->size; int fd = fileno(file->fp); int flags = MAP_SHARED; @@ -1820,6 +2064,29 @@ struct llama_mmap { strerror(errno)); } if (prefetch) { flags |= MAP_POPULATE; } + if (use_thp) { + size_t huge = get_default_huge_page_size(); + auto size = huge*((file->size + huge - 1)/huge); + addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0); + if (addr != MAP_FAILED) { + printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024)); + fflush(stdout); + size_t tot = 0; + while (tot < file->size) { + auto n_read = pread(fd, static_cast(addr) + tot, file->size - tot, tot); + if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno))); + printf("."); fflush(stdout); + tot += n_read; + } + printf(" done\n"); + mapped_fragments.emplace_back(0, file->size); + mapped_page_size = huge; + return; + } + else { + fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno)); + } + } #endif addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); if (addr == MAP_FAILED) { // NOLINT @@ -1864,7 +2131,7 @@ struct llama_mmap { void unmap_fragment(size_t first, size_t last) { // note: this function must not be called multiple times with overlapping ranges // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings - int page_size = sysconf(_SC_PAGESIZE); + int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE); align_range(&first, &last, page_size); size_t len = last - first; @@ -1906,6 +2173,28 @@ struct llama_mmap { mapped_fragments = std::move(new_mapped_fragments); } +#ifdef __linux__ + static int get_default_huge_page_size() { + int pg_size = 2048; + std::ifstream in("/proc/meminfo"); + if (in) { + std::string line; + while (true) { + std::getline(in, line); + if (in.fail()) break; + if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) { + std::istringstream str(line.data() + pos + 13); + int aux; + str >> aux; + if (!str.fail()) pg_size = aux; + break; + } + } + } + return pg_size * 1024; + } +#endif + ~llama_mmap() { for (const auto & frag : mapped_fragments) { if (munmap((char *) addr + frag.first, frag.second - frag.first)) { @@ -1916,7 +2205,7 @@ struct llama_mmap { #elif defined(_WIN32) static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, [[maybe_unused]] bool use_thp = false) { GGML_UNUSED(numa); size = file->size; @@ -1978,10 +2267,11 @@ struct llama_mmap { #else static constexpr bool SUPPORTED = false; - llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false, bool use_thp = false) { GGML_UNUSED(file); GGML_UNUSED(prefetch); GGML_UNUSED(numa); + GGML_UNUSED(use_thp); throw std::runtime_error("mmap not supported"); } @@ -2255,6 +2545,7 @@ enum e_model { MODEL_16B, MODEL_20B, MODEL_30B, + MODEL_32B, MODEL_34B, MODEL_35B, MODEL_40B, @@ -2262,6 +2553,7 @@ enum e_model { MODEL_70B, MODEL_236B, MODEL_314B, + MODEL_405B, MODEL_671B, MODEL_SMALL, MODEL_MEDIUM, @@ -2274,6 +2566,8 @@ enum e_model { MODEL_10B_128x3_66B, MODEL_57B_A14B, MODEL_27B, + MODEL_17B_16E, + MODEL_17B_128E, }; static const size_t kiB = 1024; @@ -2306,6 +2600,7 @@ struct llama_hparams { uint32_t n_layer; uint32_t n_rot; uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; @@ -2335,7 +2630,9 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; + float rope_freq_base_train_swa; float rope_freq_scale_train; + float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul; @@ -2358,6 +2655,14 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + uint32_t n_moe_layer_step = 0; + bool use_kq_norm = true; + uint32_t n_attn_chunk = 0; + // values below seems to be fixed on llama4 + uint32_t n_no_rope_layer_step = 4; + uint32_t n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -2374,6 +2679,7 @@ struct llama_hparams { if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; if (this->n_swa != other.n_swa) return true; + if (this->n_swa_pattern != other.n_swa_pattern) return false; if (this->n_embd_head_k != other.n_embd_head_k) return true; if (this->n_embd_head_v != other.n_embd_head_v) return true; if (this->n_expert != other.n_expert) return true; @@ -2503,6 +2809,11 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + int mla_attn; + int attn_max_batch; + bool fused_moe_up_gate; + int min_experts; + float thresh_experts; enum llama_pooling_type pooling_type; @@ -2541,6 +2852,8 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * wk_b; + struct ggml_tensor * wv_b; struct ggml_tensor * wq_cross; struct ggml_tensor * wk_cross; struct ggml_tensor * wv_cross; @@ -2626,6 +2939,9 @@ struct llama_layer { struct ggml_tensor * ffn_gate_scale; struct ggml_tensor * ffn_up_scale; struct ggml_tensor * ffn_down_scale; + + std::unique_ptr computed_wk_b; + std::unique_ptr computed_wv_b; }; struct llama_kv_cell { @@ -2674,6 +2990,10 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + // DeepSeek MLA + std::vector kv_l; + std::vector kvt_l; + std::vector ctxs; std::vector bufs; @@ -2839,6 +3159,8 @@ struct llama_context { struct llama_kv_cache kv_self; struct llama_control_vector cvec; + std::vector scale_data; + std::unordered_map lora_adapters; std::vector backends; @@ -2915,6 +3237,7 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] }; struct llama_lora_weight { @@ -3104,8 +3427,8 @@ static bool llama_kv_cache_init( cache.size = kv_size; cache.used = 0; - cache.type_k = type_k; - cache.type_v = type_v; + cache.type_k = type_k; + cache.type_v = type_v; cache.cells.clear(); cache.cells.resize(kv_size); @@ -3132,7 +3455,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -3145,20 +3468,90 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (model.arch == LLM_ARCH_DEEPSEEK2) { + bool have_wkv_b = true; + for (auto& l : model.layers) { + if (!l.wkv_b) { + have_wkv_b = false; + break; + } + } + if (!have_wkv_b) { + if (cparams.mla_attn != 1) { + LLAMA_LOG_WARN("=========================================================\n"); + LLAMA_LOG_WARN("%s: missing wkv_b tensor(s)\n", __func__); + LLAMA_LOG_WARN("%s: changing MLA from %d to 1\n", __func__, cparams.mla_attn); + if (cparams.mla_attn > 1) { + LLAMA_LOG_WARN("%s: ** Prompt processing performance will be crippled **\n", __func__); + } + LLAMA_LOG_WARN("=========================================================\n"); + // Sorry for the hack. + auto& non_cparams = const_cast(cparams); + non_cparams.mla_attn = 1; + } + } + } + if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) { + // DeepSeek MLA + cache.kv_l.reserve(n_layer); + if (cparams.mla_attn == 1 && !cparams.flash_attn) { + cache.kvt_l.reserve(n_layer); + } + } else { + cache.k_l.reserve(n_layer); + cache.v_l.reserve(n_layer); + } + + bool warn = true; + int n_mla = 0; for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_head = hparams.n_head(i); + const uint32_t n_head_kv = hparams.n_head_kv(i); + const uint32_t n_embd_head_k= hparams.n_embd_head_k; + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + ggml_tensor * k; + ggml_tensor * v; + if (cparams.mla_attn) { + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + //LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); + if (cparams.flash_attn) { + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + } else { + auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; + ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + if (cparams.mla_attn == 1) { + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); + } + } + n_mla++; + } + else { + k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + } + } + if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) { + LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer)); + LLAMA_LOG_ERROR("%s: bailing out\n", __func__); + GGML_ABORT("fatal error"); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -3738,7 +4131,7 @@ static size_t llama_model_max_nodes(const llama_model & /*model*/) { // return 32768; //} - return 8192; + return 65536; } struct llama_model_loader { @@ -3752,6 +4145,7 @@ struct llama_model_loader { bool use_mmap = false; bool check_tensors; bool repack_tensors = false; + bool use_thp = false; llama_files files; llama_ftype ftype; @@ -3778,6 +4172,7 @@ struct llama_model_loader { std::vector weights; std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; struct gguf_context * meta = NULL; std::vector contexts; @@ -3785,7 +4180,9 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3797,6 +4194,8 @@ struct llama_model_loader { } } + tensor_buft_overrides = param_tensor_buft_overrides_p; + struct ggml_context * ctx = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -3934,6 +4333,7 @@ struct llama_model_loader { case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break; case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break; case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break; @@ -3944,6 +4344,7 @@ struct llama_model_loader { case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break; case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break; + case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; @@ -4046,6 +4447,7 @@ struct llama_model_loader { this->use_mmap = use_mmap; this->check_tensors = check_tensors; this->repack_tensors = repack_tensors; + this->use_thp = use_thp; } ~llama_model_loader() { @@ -4359,12 +4761,12 @@ struct llama_model_loader { } } - void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr) { + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false) { if (use_mmap) { mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa())); + std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa(), use_thp)); mmaps_used.emplace_back(mapping->size, 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); @@ -4665,6 +5067,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; @@ -4681,6 +5084,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4"; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8"; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; @@ -4775,6 +5179,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; + case MODEL_32B: return "32B"; case MODEL_34B: return "34B"; case MODEL_35B: return "35B"; case MODEL_40B: return "40B"; @@ -4782,6 +5187,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_70B: return "70B"; case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; + case MODEL_405B: return "405B"; case MODEL_671B: return "671B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; @@ -4794,6 +5200,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; case MODEL_27B: return "27B"; + case MODEL_17B_16E: return "17Bx16E (Scout)"; + case MODEL_17B_128E: return "17Bx128E (Maverick)"; default: return "?B"; } } @@ -4895,6 +5303,10 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -4913,7 +5325,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158 || model.arch == LLM_ARCH_DECI) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -4951,6 +5363,35 @@ static void llm_load_hparams( } } } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + switch (hparams.n_expert) { + case 16: model.type = MODEL_17B_16E; break; + case 128: model.type = MODEL_17B_128E; break; + default: model.type = MODEL_UNKNOWN; + } + + if (model.type == MODEL_17B_128E) { + hparams.use_kq_norm = false; + } + } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + case 162: model.type = e_model::MODEL_405B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5132,6 +5573,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -5242,6 +5699,28 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GEMMA3: + { + hparams.n_swa_pattern = 6; + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_2B; break; + case 34: model.type = e_model::MODEL_4B; break; + case 48: model.type = e_model::MODEL_12B; break; + case 62: model.type = e_model::MODEL_27B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + hparams.f_attention_scale = model.type == e_model::MODEL_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -5398,6 +5877,25 @@ static void llm_load_hparams( } break; case LLM_ARCH_DEEPSEEK2: { + if (hparams.n_head_kv() == 1) { + printf("==========================================================================\n"); + printf("Detected incompatible DeepSeek model.\n"); + printf("Will try to fix, but there are no guarantees\n\n"); + printf("*** Your prompt processing speed will be crippled ***\n\n"); + printf("Consider making your own ik_llama.cpp compatible model or\n"); + printf("ask the model provider to make one for you,\n"); + int n_nead_kv = hparams.n_gqa(); + if (n_nead_kv%16 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 || + hparams.n_rot != 64) { + printf("Sorry, uknown model => cannot fix it => bailing out\n"); + GGML_ABORT("Fatal error"); + } + for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv; + hparams.n_embd_head_k = 192; + hparams.n_embd_head_v = 128; + printf("==========================================================================\n"); + //GGML_ABORT("Fatal error"); + } bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); @@ -5410,7 +5908,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == 0) { + if (hparams.expert_gating_func == 0) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; @@ -5433,6 +5931,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_9B; break; + case 61: model.type = e_model::MODEL_32B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5442,6 +5949,16 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET_B158: + case LLM_ARCH_BITNET_25: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: model.type = e_model::MODEL_2B; break; // bitnet2b_2501 + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_T5: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5508,6 +6025,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_COHERE2: + { + hparams.n_swa_pattern = 4; + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5729,6 +6257,7 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; vocab.tokenizer_clean_spaces = false; } else if ( + tokenizer_pre == "glm4" || tokenizer_pre == "chatglm-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; vocab.special_bos_id = -1; @@ -5752,6 +6281,23 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6100,6 +6646,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -6203,6 +6750,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } + if (model.arch == LLM_ARCH_QWEN3MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); @@ -6339,6 +6890,26 @@ static bool llm_load_tensors( model.ctxs.push_back(ctx); } + auto ctx_for_buft = [&model, &ctx_map, ctx_size](ggml_backend_buffer_type_t buft) -> ggml_context * { + if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second; + + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + model.ctxs.emplace_back(ctx); + + return ctx; + }; + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, model.ctxs.size()*ctx_size/1024.0/1024.0); // create tensors for the weights @@ -6355,6 +6926,7 @@ static bool llm_load_tensors( const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -6372,6 +6944,20 @@ static bool llm_load_tensors( model.layers.resize(n_layer); + auto create_tensor = [&ml, &ctx_map, &ctx_for_buft] (ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + if (ml.tensor_buft_overrides) { + for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(name, pattern)) { + LLAMA_LOG_INFO("Tensor %s buffer type overriden to %s\n", name.c_str(), ggml_backend_buft_name(overrides->buft)); + ctx = ctx_for_buft(overrides->buft); + break; + } + } + } + return ml.create_tensor(ctx, name, ne, flags); + }; + const auto tn = LLM_TN(model.arch); switch (model.arch) { case LLM_ARCH_LLAMA: @@ -6380,16 +6966,16 @@ static bool llm_load_tensors( case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6399,39 +6985,39 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); if (n_expert == 0) { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional MLP bias - layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } else { // merge split expert into a single tensor for compatibility with older models // requires disabling mmap @@ -6459,22 +7045,147 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_DECI: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + } + + // optional bias tensors + + + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (n_ff > 0) { + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + } + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + + if (n_ff > 0) { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + + // optional MLP bias + layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_LLAMA4: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_NOT_REQUIRED); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, + llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_GROK: { if (n_expert == 0) { throw std::runtime_error("Grok model cannot have zero experts"); } - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6484,23 +7195,23 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } else { // merge split expert into a single tensor for compatibility with older models // requires disabling mmap @@ -6526,7 +7237,7 @@ static bool llm_load_tensors( } } - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); } } break; case LLM_ARCH_DBRX: @@ -6535,12 +7246,12 @@ static bool llm_load_tensors( throw std::runtime_error("DBRX model cannot have zero experts"); } - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6549,25 +7260,25 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; case LLM_ARCH_BAICHUAN: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6576,32 +7287,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_FALCON: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -6611,32 +7322,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2 = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_STARCODER: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { // needs to be on GPU - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6647,37 +7358,37 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.type_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); if (model.arch == LLM_ARCH_BERT) { - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); } - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -6686,45 +7397,45 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; if (model.arch == LLM_ARCH_BERT) { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); } else { - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); } - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); if (model.arch == LLM_ARCH_BERT) { - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); } else { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); } - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); - layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; case LLM_ARCH_JINA_BERT_V2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings - model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings + model.type_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -6732,51 +7443,51 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; // JinaBertLayer - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens - layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens + layer.bo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm - layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm + layer.attn_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2 = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); - layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; case LLM_ARCH_BLOOM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6785,38 +7496,38 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_MPT: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -6826,43 +7537,43 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); // AWQ ScaleActivation layer - layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_act = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_STABLELM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6871,40 +7582,40 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_QWEN: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6913,30 +7624,30 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); } } break; case LLM_ARCH_QWEN2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6946,33 +7657,33 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_QWEN2MOE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6981,21 +7692,21 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); @@ -7003,29 +7714,31 @@ static bool llm_load_tensors( // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}); + layer.ffn_gate_inp_shexp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}); } } break; - case LLM_ARCH_PHI2: + case LLM_ARCH_QWEN3: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - model.output_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } } for (int i = 0; i < n_layer; ++i) { @@ -7034,43 +7747,118 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_QWEN3MOE: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(n_expert > 0); + GGML_ASSERT(n_expert_used > 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + } + } break; + case LLM_ARCH_PHI2: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); } - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_PHI3: { const int64_t n_embd_head = n_embd / n_head; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); } for (int i = 0; i < n_layer; ++i) { @@ -7079,28 +7867,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); - layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); - layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_long = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; case LLM_ARCH_PLAMO: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7109,28 +7897,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GPT2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7139,34 +7927,34 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_CODESHELL: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7175,32 +7963,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_ORION: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7208,30 +7996,30 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_INTERNLM2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7240,26 +8028,26 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - // layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + // layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GEMMA: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7267,26 +8055,26 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); } } break; case LLM_ARCH_GEMMA2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7294,34 +8082,66 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); - layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); + } + } break; + case LLM_ARCH_GEMMA3: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; case LLM_ARCH_STARCODER2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7332,29 +8152,29 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional bias tensors - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}); } } break; case LLM_ARCH_MAMBA: @@ -7367,16 +8187,16 @@ static bool llm_load_tensors( // only an expansion factor of 2 is supported for now GGML_ASSERT(2 * n_embd == d_inner); - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7387,32 +8207,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; // norm - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); - layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); - layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + layer.ssm_x = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); - layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); - layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + layer.ssm_dt = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); - layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + layer.ssm_a = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); // out_proj - layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; case LLM_ARCH_XVERSE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7421,28 +8241,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_COMMAND_R: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -7451,33 +8271,33 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); if (n_layer >= 64){ - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); } - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7486,25 +8306,25 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_OPENELM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -7517,28 +8337,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GPTNEOX: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7547,37 +8367,37 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_ARCTIC: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7587,24 +8407,24 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_norm_exps = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; case LLM_ARCH_DEEPSEEK2: @@ -7620,12 +8440,12 @@ static bool llm_load_tensors( const int64_t n_ff_exp = hparams.n_ff_exp; const int64_t n_expert_shared = hparams.n_expert_shared; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7634,57 +8454,67 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); if (!is_lite) { - layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + layer.attn_q_a_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); } - layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + layer.attn_kv_a_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); if (!is_lite) { - layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); - layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}); + layer.wq_a = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}); } else { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); } - layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); - layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); + layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i),{n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); + layer.wkv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { + // Incompatible mainline model. Let's see if we can still load it + layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v, n_head}, 0); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + } else { + layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); + layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); + } + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } else { - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 1); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 1); GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); // MoE branch - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); } } } break; case LLM_ARCH_BITNET: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading } const uint32_t n_ff = hparams.n_ff(); @@ -7695,44 +8525,129 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_BITNET_B158: + case LLM_ARCH_BITNET_25: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.attn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + // optional bias tensors + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_exps) { + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } else { + // merge split expert into a single tensor for compatibility with older models + // requires disabling mmap + use_mmap_buffer = false; + + ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type; + ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type; + ggml_type type_up = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type; + + layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, n_expert); + layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, n_expert); + layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, n_expert); + + ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i).c_str()); + + for (uint32_t x = 0; x < n_expert; ++x) { + // the individual experts are loaded into a view of the merged tensor + ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x); + } + } + } } } break; case LLM_ARCH_T5: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7742,55 +8657,55 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b = create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.attn_norm_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_cross = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_rel_b_cross = create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_T5ENCODER: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7800,29 +8715,29 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_JAIS: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // Output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7831,36 +8746,36 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_CHATGLM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7869,18 +8784,88 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + } + } break; + case LLM_ARCH_COHERE2: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + llama_model_loader::TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + break; + case LLM_ARCH_GLM4: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; default: @@ -7890,7 +8875,7 @@ static bool llm_load_tensors( ml.done_getting_tensors(); - ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr); + ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr, ml.use_thp); model.mappings.reserve(ml.mappings.size()); // create the backend buffers @@ -7911,7 +8896,7 @@ static bool llm_load_tensors( // only the mmap region containing the tensors in the model is mapped to the backend buffer // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size - if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) { + if (ml.use_mmap && use_mmap_buffer && (buft == llama_default_buffer_type_cpu(true) || buft == ggml_backend_cpu_buffer_type())) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { void * addr = nullptr; size_t first, last; @@ -8018,6 +9003,146 @@ static bool llm_load_tensors( } } + if (model.arch == LLM_ARCH_DEEPSEEK2) { + int n_to_compute = 0; + for (auto& l : model.layers) { + if (!l.wk_b) ++n_to_compute; + } + if (n_to_compute > 0) { + // Prepare wk_b tensors to enable MLA usage also for model files that do not include + // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp) + // We do it here because otherwise wkv_b may get run-time-repacked, which will make + // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically + // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b + // even if it is not needed (because MLA is not being used). If we wanted to avoid + // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters + // to the model loading function. On the other hand, in some hypothetical bright future, + // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite + // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful + // to change the MLA setting on the fly, depending on context. In that case, having prepared + // the MLA tensors here is the right ting to do^TM. + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const int32_t n_embd_head_v = hparams.n_embd_head_v; + const int32_t n_head = hparams.n_head(0); + std::vector work_data; + LLAMA_LOG_INFO("============ %s: need to compute %d wk_b tensors\n", __func__, n_to_compute); + for (int il = 1; il < n_layer; ++il) { + // Somehow the number of heads is being defined as being per layer. Not sure why this is the + // case, but for now we do not support strange models that have different numbers of heads + // in different model layers. + if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration"); + } + auto total_size_wkb = 0; + size_t max_wkv_size = 0; + size_t max_wk_size = 0; + for (auto& l : model.layers) { + if (!l.wk_b) { + auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type; + auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head; + max_wk_size = std::max(max_wk_size, size); + if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) { + max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b)); + } + } + } + auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float); + context_size *= 2; // just in case; + std::vector wkv_buffer; + if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size); + // So, transposing tensors and then making them contiguous as needed for wk_b may or may not + // be supported on all backends. Hence, to be sure that the preparation of wk_b will + // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to + // the bacikend where wkv_b is stored. + ggml_init_params params{context_size, nullptr, true}; + auto ctx = ggml_init(params); + auto graph = ggml_new_graph_custom(ctx, 8, false); + std::vector tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size); + for (int il = 0; il < n_layer; ++il) { + auto& l = model.layers[il]; + if (l.wk_b) continue; + auto wkv_b = *l.wkv_b; + if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) { + ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b)); + wkv_b.data = wkv_buffer.data(); + } + auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head, + l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0); + auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32); + wk_b_f32->data = tensor_data.data(); + auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32); + auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview); + wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32); + + auto new_type = ggml_is_quantized(wkv_b.type) ? + wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type; + auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type); + wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t); + + ggml_build_forward_expand(graph, wk_b); + + auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan.work_size > work_data.size()) work_data.resize(plan.work_size); + plan.work_data = work_data.data(); + + auto status = ggml_graph_compute(graph, &plan); + if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b"); + + auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight"; + + l.computed_wk_b = std::make_unique(*wk_b); + l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b)); + l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer); + l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents + // of wk_b, which no longer exist, and will therefore crash. + for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr; + ggml_set_name(l.computed_wk_b.get(), name.c_str()); + ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b)); + if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) { + iqk_modify_tensor(l.computed_wk_b.get()); + } + + l.wk_b = l.computed_wk_b.get(); + + ggml_graph_clear(graph); + auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head, + l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope)); + wv_b->data = tensor_data.data(); + ggml_build_forward_expand(graph, wv_b); + plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan.work_size > work_data.size()) work_data.resize(plan.work_size); + plan.work_data = work_data.data(); + status = ggml_graph_compute(graph, &plan); + if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b"); + + name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight"; + + l.computed_wv_b = std::make_unique(*wv_b); + l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b)); + l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer); + l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents + // of wk_b, which no longer exist, and will therefore crash. + for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr; + ggml_set_name(l.computed_wv_b.get(), name.c_str()); + ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b)); + if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) { + iqk_modify_tensor(l.computed_wv_b.get()); + } + + l.wv_b = l.computed_wv_b.get(); + + printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], + ggml_backend_buffer_name(l.computed_wk_b->buffer)); + + ggml_graph_clear(graph); + } + ggml_free(ctx); + } + } + if (use_mmap_buffer) { for (auto & mapping : ml.mappings) { model.mappings.emplace_back(std::move(mapping)); @@ -8082,7 +9207,8 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, + params.repack_tensors, params.use_thp, params.kv_overrides, params.tensor_buft_overrides); model.hparams.vocab_only = params.vocab_only; @@ -8219,17 +9345,24 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, - (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); - cb(k_cache_view, "k_cache_view", il); + //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, + // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); + //cb(k_cache_view, "k_cache_view", il); + + auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k); + ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv, + k_row_size, k_row_size*n_head_kv*kv_head); // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cache_view = nullptr; if (cparams.flash_attn) { @@ -8489,6 +9622,7 @@ llm_expert_gating_func_type gating_op, int il) { int64_t n_embd = cur->ne[0]; int64_t n_tokens = cur->ne[1]; + bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); @@ -8517,8 +9651,15 @@ llm_expert_gating_func_type gating_op, cb(selection_probs, "ffn_moe_probs_biased", il); } + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (lctx.model.arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -8543,36 +9684,37 @@ llm_expert_gating_func_type gating_op, } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); + if (weight_before_ffn) { + // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) + ggml_tensor * repeated = ggml_new_tensor_3d(ctx, cur->type, n_embd, n_expert_used, n_tokens); + repeated = ggml_repeat(ctx, cur, repeated); // [n_embd, n_expert_used, n_tokens] + cur = ggml_mul(ctx, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } - // This is equivalent to the commented out code below - ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + ggml_tensor * par; + if (lctx.cparams.fused_moe_up_gate) { + par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } else { + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); - //switch (type_op) { - // case LLM_FFN_SILU: - // { - // gate = ggml_silu(ctx, gate); - // cb(gate, "ffn_moe_silu", il); - // } break; - // case LLM_FFN_GELU: - // { - // gate = ggml_gelu(ctx, gate); - // cb(gate, "ffn_moe_gelu", il); - // } break; - // default: - // GGML_ABORT("fatal error"); - //} - //ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + // This is equivalent to the commented out code below + par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } cb(par, "ffn_moe_gate_par", il); ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - experts = ggml_mul(ctx, experts, weights); + if (!weight_before_ffn) { + experts = ggml_mul(ctx, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } if (n_expert_used == 1) { return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0)); @@ -8642,7 +9784,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il], n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa), ggml_row_size(kv.k_l[il]->type, n_embd_head_k), 0); cb(k, "k", il); @@ -8665,52 +9807,19 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { + // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA + // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. + // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel. + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || + (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below - - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - //kq = ggml_scale(ctx, kq, 30); - - kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); - } - - if (hparams.attn_soft_cap) { - //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, - 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); - } - cb(kq, "kq_soft_max_ext", il); - - GGML_ASSERT(kv.size == n_ctx); - - // split cached v into n_head heads + // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, @@ -8719,14 +9828,100 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); + if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + //kq = ggml_scale(ctx, kq, 30); + + kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); + } + + if (hparams.attn_soft_cap) { + //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv.size == n_ctx); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } + else { + // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]; + GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]); + int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch; + n_step = std::min(n_step, int(k->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + auto r2k = q->ne[2] / k->ne[2]; + auto r2v = q->ne[2] / v->ne[2]; + n_step = q->ne[2]; + n_per_step = 1; + ggml_tensor * kqv; + for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) { + int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12; + int i02 = i12/r2k; + auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02); + auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); + auto kq_i = ggml_mul_mat(ctx, k_i, q_i); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { + ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); + } + if (model.arch == LLM_ARCH_GROK) { + kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f); + } + if (hparams.attn_soft_cap) { + kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + i02 = i12 / r2v; + auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02); + auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i); + if (i12 == 0) { + kqv = kqv_i; + } else { + kqv = ggml_concat(ctx, kqv, kqv_i, 2); + } + } + ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } } ggml_build_forward_expand(graph, cur); @@ -8821,6 +10016,11 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const int mla_attn; + const int attn_max_batch; + const bool fused_moe_up_gate; + const int min_experts; + const float thresh_experts; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8836,7 +10036,8 @@ struct llm_build_context { llama_context & lctx, const llama_batch & batch, const llm_build_cb & cb, - bool worst_case) : + bool worst_case, + bool warmup) : model (lctx.model), lctx (lctx), hparams (model.hparams), @@ -8854,7 +10055,7 @@ struct llm_build_context { n_embd_head_v (hparams.n_embd_head_v), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), + n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -8870,6 +10071,11 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + mla_attn (cparams.mla_attn), + attn_max_batch (cparams.attn_max_batch), + fused_moe_up_gate(cparams.fused_moe_up_gate), + min_experts (cparams.min_experts), + thresh_experts (cparams.thresh_experts), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -9039,6 +10245,14 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_inpup_scale(int n_tokens) { + int n_pos_per_token = 1; + lctx.inp_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token); + cb(lctx.inp_scale, "inp_scale", -1); + ggml_set_input(lctx.inp_scale); + return lctx.inp_scale; + } + struct ggml_tensor * build_rope_factors(int il) { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; @@ -9219,22 +10433,36 @@ struct llm_build_context { GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - struct ggml_tensor * cur; - struct ggml_tensor * inpL; + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_attn_scale = nullptr; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); + if (model.arch == LLM_ARCH_LLAMA4) { + inp_attn_scale = build_inpup_scale(n_tokens); + } + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ? + ggml_tensor * KQ_mask = build_inp_KQ_mask(); + ggml_tensor * KQ_mask_swa = nullptr; + if (hparams.n_swa > 0 && hparams.n_swa_pattern > 0) { + KQ_mask_swa = build_inp_KQ_mask_swa(); + } //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f; for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; + bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true; + auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ? + KQ_mask_swa : KQ_mask; + // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, @@ -9272,6 +10500,226 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } + if (use_rope) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_attn_scale); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].wv ? + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // non-MoE + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else if (model.arch == LLM_ARCH_LLAMA4) { + // llama4 MoE + ggml_tensor * ffn_inp_normed = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SIGMOID, + cb, il); + + // Shared experts + ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, + cb, il); + cb(cur, "ffn_moe_out", il); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].ffn_down_exps ? + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + // For Granite architecture + if (hparams.f_logit_scale) { + // Why is hparams.f_logit_scale not simply absorbed into model.output ? + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_deci() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); + + if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B + cur = inpL; + } else { + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + if (n_head > 0 && n_head_kv == 0) { // "linear attention" of Llama-3_1-Nemotron-51B + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "wo", il); + } else if (n_head > 0) { + // self-attention + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9299,14 +10747,21 @@ struct llm_build_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_ff == 0) { + continue; + } + if (hparams.f_residual_scale) { - // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].wv ? cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); } - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + // modified to support attention-free layer of Llama-3_1-Nemotron-51B + struct ggml_tensor * ffn_inp = cur; + if (n_head > 0) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } // feed-forward network if (model.layers[il].ffn_gate_inp == nullptr) { @@ -9322,30 +10777,9 @@ struct llm_build_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); - } else { - // MoE branch - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_moe_ffn(ctx0, lctx, cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLM_EXPERT_GATING_FUNC_SOFTMAX, - cb, il); - cb(cur, "ffn_moe_out", il); } - // For Granite architecture if (hparams.f_residual_scale) { - // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].ffn_down_exps ? cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); } @@ -9369,9 +10803,7 @@ struct llm_build_context { // lm_head cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); - // For Granite architecture if (hparams.f_logit_scale) { - // Why is hparams.f_logit_scale not simply absorbed into model.output ? cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); } @@ -11162,6 +12594,245 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_qwen3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_qwen3moe() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -12357,6 +14028,126 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_gemma3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head_k = hparams.n_embd_head_k; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (batch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + // gemma3 requires different mask for layers using sliding window (SWA) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true); + + // "5-to-1 interleaved attention" + // 5 layers of local attention followed by 1 layer of global attention + static const int sliding_window_pattern = 6; + + for (int il = 0; il < n_layer; ++il) { + const bool is_sliding = (il + 1) % sliding_window_pattern; + const float freq_base_l = is_sliding ? 10000.0f : freq_base; + const float freq_scale_l = is_sliding ? 1.0f : freq_scale; + struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il); + } + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } struct ggml_cgraph * build_starcoder2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -13335,6 +15126,10 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // whether to use n_tokens as the matrix dimension during multiplication or n_head + // n_tokens is higher during prompt processing, this allows to optimize for this case + bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head; + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -13373,87 +15168,339 @@ struct llm_build_context { cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + struct ggml_tensor * q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, hparams.n_embd_head_k), ggml_row_size(q->type, hparams.n_embd_head_k * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); - cb(q_pe, "q_pe", il); + cb(q_rope, "q_rope", il); + + q_rope = ggml_rope_ext( + ctx0, q_rope, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_rope, "q_rope", il); // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_rope_compresseed, "kv_rope_compresseed", il); + + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_rope_compresseed->nb[1], + kv_rope_compresseed->nb[1], + ggml_row_size(kv_rope_compresseed->type, kv_lora_rank)); + cb(k_rope, "k_rope", il); + + // shared RoPE key + k_rope = ggml_rope_ext( + ctx0, k_rope, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_rope, "k_rope", il); // split into {kv_lora_rank, n_tokens} - struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens, + kv_rope_compresseed->nb[1], 0); cb(kv_compressed, "kv_compressed", il); - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); - - //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, cb, il); cb(kv_compressed, "kv_compressed", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + if (lctx.cparams.mla_attn) { - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); + ggml_tensor * kv_cache_trans; - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); + if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) { + ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); + kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); + } - //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); + ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); + cb(kvr, "kvr", il); - // shared RoPE key - //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); + auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); + ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens, + row_size, row_size*kv_head); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view)); + ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache", il); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + ggml_tensor * kqv; - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3 - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); + + auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024); + int n_max_head = n_head; + if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) { + while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) { + n_max_head /= 2; kv_f32_size /= 2; + } + } + GGML_ASSERT(n_head % n_max_head == 0); + + auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head; + + auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, + kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + + // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when + // the KV cache is quantized. Hence, in that case we will simply use fp16 for now. + // The downside of the following line is that fp16 will be used even if attention is computed on the CPU + // if the build is with CUDA enabled. + auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16; + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1; + ggml_tensor * k_rope; + if (kv_cache_rope->type == kv_type) { + k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); + } else { + auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16); + k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater); + } + cb(k_rope, "k_rope", il); + + auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_concat", il); + + ggml_build_forward_expand(gf, q); + + for (int iter = 0; iter < n_head/n_max_head; ++iter) { + + auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head, + model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter); + + auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope); + cb(kv_f32, "kv_f32", il); + + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head, + ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + cb(v_f32, "v_f32", il); + + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head, + ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope_f32, "k_nope_f32", il); + + auto v = ggml_cast(ctx0, v_f32, kv_type); + cb(v, "v", il); + + auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type); + cb(k_nope, "k_nope", il); + + ggml_build_forward_expand(gf, k_nope); + ggml_build_forward_expand(gf, v); + + auto k = ggml_concat(ctx0, k_nope, k_rope, 0); + cb(k, "k", il); + + ggml_build_forward_expand(gf, k); + + auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head, + q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter); + + kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + if (q->ne[1] <= 8) { + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + } + cb(kqv, "kqv", il); + + if (iter == 0) { + cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens); + } else { + cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0); + } + + } + + } + else { + + ggml_tensor * kqv_compressed; + + auto wkv_b = model.layers[il].wkv_b; + auto wk_b = model.layers[il].wk_b->ne[1] == kv_lora_rank ? model.layers[il].wk_b + : ggml_reshape_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head); + + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); + cb(q_nope2, "q_nope2", il); + + ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); + cb(q, "q", il); + + if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache_lora, "kv_cache_lora", il); + + kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv_compressed, "kqv_compressed", il); + + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); + + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); + } + + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); + } + + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + if (kv_cache->ne[1] < 256) { + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + cb(kq, "kq", il); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + } + cb(kqv_compressed, "kqv_compressed", il); + } + } + + auto wv_b = model.layers[il].wv_b; + if (wv_b->ne[1] != n_embd_head_v) { + wv_b = ggml_reshape_3d(ctx0, wv_b, kv_lora_rank, n_embd_head_v, n_head); + cb(wv_b, "wv_b", il); + } + // There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is + // not contiguous. So, for now, we create wv_b during model loading and use that + // instead of the commented out 3D view below. + //auto wv_b = ggml_view_3d(ctx0, wkv_b, kv_lora_rank, n_embd_head_v, n_head, + // wkv_b->nb[1], wkv_b->nb[1]*(n_embd_head_v + n_embd_head_qk_nope), + // wkv_b->nb[1]*n_embd_head_qk_nope); + //cb(wv_b, "wv_b", il); + + kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); + + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); + + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); + } + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + + } + else { + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + + } } if (il == n_layer - 1) { @@ -13690,6 +15737,287 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet_158() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + // printf("%f\n\n\n\n",((float*)rope_factors->data)[1]); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + NULL, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + // n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } + cb(cur, "ffn_down", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_cohere2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const float f_logit_scale = hparams.f_logit_scale; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + // cohere2 requires different mask for layers using sliding window (SWA) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + + // sliding window switch pattern + const int32_t sliding_window_pattern = 4; + + for (int il = 0; il < n_layer; ++il) { + // three layers sliding window attention (window size 4096) and ROPE + // fourth layer uses global attention without positional embeddings + const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1); + struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + struct ggml_tensor * ffn_inp = cur; + + // self-attention + { + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + if (is_sliding) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, + beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + } else { + // For non-sliding layers, just reshape without applying RoPE + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + } + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + struct ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, + cb, il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_t5_encoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -14232,6 +16560,140 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_glm4() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // Output projection + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -14240,7 +16702,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14257,7 +16719,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14274,7 +16736,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14291,6 +16753,10 @@ static struct ggml_cgraph * llama_build_graph( bool worst_case) { const auto & model = lctx.model; +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif + // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { if (il >= 0) { @@ -14324,17 +16790,26 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; - struct llm_build_context llm(lctx, batch, cb, worst_case); + const llama_vocab * vocab = llama_get_vocab(&lctx); + llama_token bos = llama_token_bos_impl(*vocab); + llama_token eos = llama_token_eos_impl(*vocab); + bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos); + struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up); llm.init(); switch (model.arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { result = llm.build_llama(); } break; + case LLM_ARCH_DECI: + { + result = llm.build_deci(); + } break; case LLM_ARCH_BAICHUAN: { result = llm.build_baichuan(); @@ -14385,6 +16860,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen2moe(); } break; + case LLM_ARCH_QWEN3: + { + result = llm.build_qwen3(); + } break; + case LLM_ARCH_QWEN3MOE: + { + result = llm.build_qwen3moe(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); @@ -14425,6 +16908,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_gemma2(); } break; + case LLM_ARCH_GEMMA3: + { + result = llm.build_gemma3(); + } break; case LLM_ARCH_STARCODER2: { result = llm.build_starcoder2(); @@ -14469,10 +16956,23 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_chatglm(); } break; + case LLM_ARCH_GLM4: + { + result = llm.build_glm4(); + } break; case LLM_ARCH_BITNET: { result = llm.build_bitnet(); } break; + case LLM_ARCH_BITNET_B158: + case LLM_ARCH_BITNET_25: + { + result = llm.build_bitnet_158(); + } break; + case LLM_ARCH_COHERE2: + { + result = llm.build_cohere2(); + } break; case LLM_ARCH_T5: { if (lctx.is_encoding) { @@ -14500,6 +17000,11 @@ static struct ggml_cgraph * llama_build_graph( llm.free(); +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("%s(...): %d us\n", __func__, int(tim2-tim1)); +#endif + return result; } @@ -14556,6 +17061,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // set input data // +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -14579,6 +17087,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + if (lctx.inp_pos && lctx.inp_scale) { + int n_tokens = batch.n_tokens; + GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens); + if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens); + int n_pos_per_token = 1; + for (int i = 0; i < n_tokens; ++i) { + lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f; + } + ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale)); + } + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -14660,8 +17179,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // may need to cut off old tokens for sliding window if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } } data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } @@ -14915,6 +17441,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("%s(...): %d us\n", __func__, int(tim2-tim1)); +#endif } // Make sure enough space is available for outputs. @@ -15923,6 +18453,115 @@ static void llama_tensor_dequantize_internal( workers.clear(); } +static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || + new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || + new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || + new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || + new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R8 || + new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || + new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || + new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| + new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || + new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4) { + if (nx % QK_K != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } + } + if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) { + if (nx % QK_IQ1BN != 0) { + convert_incompatible_tensor = true; + } + } + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_Q4_K_R4: + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; + case GGML_TYPE_IQ6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + } + return new_type; +} + +static std::pair interleaved_properties(ggml_type type) { + static std::unordered_map> k_map = { + { GGML_TYPE_Q4_0_4_4, { GGML_TYPE_Q4_0, 4} }, + { GGML_TYPE_Q4_0_4_8, { GGML_TYPE_Q4_0, 4} }, + { GGML_TYPE_Q4_0_8_8, { GGML_TYPE_Q4_0, 8} }, + { GGML_TYPE_Q4_0_R8, { GGML_TYPE_Q4_0, 8} }, + { GGML_TYPE_Q5_0_R4, { GGML_TYPE_Q5_0, 4} }, + { GGML_TYPE_Q6_0_R4, { GGML_TYPE_Q6_0, 4} }, + { GGML_TYPE_Q8_0_R8, { GGML_TYPE_Q8_0, 8} }, + { GGML_TYPE_Q2_K_R4, { GGML_TYPE_Q2_K, 4} }, + { GGML_TYPE_Q3_K_R4, { GGML_TYPE_Q3_K, 4} }, + { GGML_TYPE_Q4_K_R4, { GGML_TYPE_Q4_K, 4} }, + { GGML_TYPE_Q5_K_R4, { GGML_TYPE_Q5_K, 4} }, + { GGML_TYPE_Q6_K_R4, { GGML_TYPE_Q6_K, 4} }, + { GGML_TYPE_IQ2_XXS_R4, { GGML_TYPE_IQ2_XXS, 4} }, + { GGML_TYPE_IQ2_XS_R4, { GGML_TYPE_IQ2_XS, 4} }, + { GGML_TYPE_IQ2_S_R4, { GGML_TYPE_IQ2_S, 4} }, + { GGML_TYPE_IQ3_XXS_R4, { GGML_TYPE_IQ3_XXS, 4} }, + { GGML_TYPE_IQ3_S_R4, { GGML_TYPE_IQ3_S, 4} }, + { GGML_TYPE_IQ4_XS_R8, { GGML_TYPE_IQ4_XS, 8} }, + { GGML_TYPE_IQ4_NL_R4, { GGML_TYPE_IQ4_NL, 4} }, + { GGML_TYPE_IQ1_S_R4, { GGML_TYPE_IQ1_S, 4} }, + { GGML_TYPE_IQ1_M_R4, { GGML_TYPE_IQ1_M, 4} }, + { GGML_TYPE_IQ2_BN_R4, { GGML_TYPE_IQ2_BN, 4} }, + { GGML_TYPE_IQ2_K_R4, { GGML_TYPE_IQ2_K, 4} }, + { GGML_TYPE_IQ3_K_R4, { GGML_TYPE_IQ3_K, 4} }, + { GGML_TYPE_IQ4_K_R4, { GGML_TYPE_IQ4_K, 4} }, + { GGML_TYPE_IQ4_KS_R4, { GGML_TYPE_IQ4_KS, 4} }, + { GGML_TYPE_IQ5_K_R4, { GGML_TYPE_IQ5_K, 4} }, + { GGML_TYPE_Q8_KV_R8, { GGML_TYPE_Q8_KV, 8} }, + { GGML_TYPE_Q8_K_R8, { GGML_TYPE_Q8_K, 8} }, + { GGML_TYPE_BF16_R16, { GGML_TYPE_BF16, 16} }, + }; + if (auto it = k_map.find(type); it != k_map.end()) return it->second; + return {type, 1}; +} + static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); @@ -15934,6 +18573,19 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; }; + auto custom_type = GGML_TYPE_COUNT; + if (qs.params->custom_quants) { + using CustomQ = std::pair; + auto& q_rules = *static_cast*>(qs.params->custom_quants); + for (auto& rule : q_rules) { + std::regex pattern(rule.first); + if (std::regex_search(name, pattern)) { + custom_type = rule.second; + break; + } + } + } + //auto get_layer = [] (const char * name) { // int il; // if (sscanf(name, "blk.%d.", &il) == 1) return il; @@ -15990,7 +18642,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ5_K; } else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && - new_type != GGML_TYPE_Q8_K_R8) { + new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) { new_type = GGML_TYPE_Q6_K; } } @@ -16016,67 +18668,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN_R4) { new_type = GGML_TYPE_IQ4_NL; } - else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || - new_type == GGML_TYPE_Q4_0_8_8) { - new_type = GGML_TYPE_Q4_0; - } - else if (new_type == GGML_TYPE_IQ4_NL_R4) { - new_type = GGML_TYPE_IQ4_NL; - } - else if (new_type == GGML_TYPE_IQ4_XS_R8) { - new_type = GGML_TYPE_IQ4_XS; - } - else if (new_type == GGML_TYPE_Q2_K_R4) { - new_type = GGML_TYPE_Q2_K; - } - else if (new_type == GGML_TYPE_Q3_K_R4) { - new_type = GGML_TYPE_Q3_K; - } - else if (new_type == GGML_TYPE_Q4_K_R4) { - new_type = GGML_TYPE_Q4_K; - } - else if (new_type == GGML_TYPE_Q5_K_R4) { - new_type = GGML_TYPE_Q5_K; - } - else if (new_type == GGML_TYPE_Q6_K_R4) { - new_type = GGML_TYPE_Q6_K; - } - else if (new_type == GGML_TYPE_Q8_K_R8) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type == GGML_TYPE_IQ2_K_R4) { - new_type = GGML_TYPE_IQ2_K; - } - else if (new_type == GGML_TYPE_IQ3_K_R4) { - new_type = GGML_TYPE_IQ3_K; - } - else if (new_type == GGML_TYPE_IQ3_S_R4) { - new_type = GGML_TYPE_IQ3_S; - } - else if (new_type == GGML_TYPE_IQ4_K_R4) { - new_type = GGML_TYPE_IQ4_K; - } - else if (new_type == GGML_TYPE_IQ5_K_R4) { - new_type = GGML_TYPE_IQ5_K; - } - else if (new_type == GGML_TYPE_IQ4_KS_R4) { - new_type = GGML_TYPE_IQ4_KS; - } - else if (new_type == GGML_TYPE_Q4_0_R8) { - new_type = GGML_TYPE_Q4_0; - } - else if (new_type == GGML_TYPE_Q5_0_R4) { - new_type = GGML_TYPE_Q5_0; - } - else if (new_type == GGML_TYPE_Q6_0_R4) { - new_type = GGML_TYPE_Q6_0; - } - else if (new_type == GGML_TYPE_Q8_0_R8) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type == GGML_TYPE_BF16_R16) { - new_type = GGML_TYPE_BF16; - } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { if (name.find("attn_v.weight") != std::string::npos) { @@ -16176,7 +18767,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { - if (qs.model.hparams.n_expert >= 4) { + if (qs.params->attn_output_type < GGML_TYPE_COUNT) new_type = qs.params->attn_output_type; + else if (qs.model.hparams.n_expert >= 4) { new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; @@ -16451,94 +19043,24 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || - new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || - new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT || - new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || - new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R8 || - new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || - new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || - new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| - new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || - new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || - new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4) { - int nx = tensor->ne[0]; - int ny = tensor->ne[1]; - if (nx % QK_K != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } + if (custom_type < GGML_TYPE_COUNT) { + new_type = custom_type; + LLAMA_LOG_INFO("Using custom type %s for tensor %s\n", ggml_type_name(new_type), name.c_str()); } - if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) { - int nx = tensor->ne[0]; - if (nx % QK_IQ1BN != 0) { - convert_incompatible_tensor = true; - } - } - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XXS_R4: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_XS_R4: - case GGML_TYPE_IQ2_KS: - case GGML_TYPE_IQ2_KT: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ2_S_R4: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_KT: - case GGML_TYPE_IQ4_KT: - case GGML_TYPE_IQ3_XXS_R4: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ3_S_R4: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q2_K_R4: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q3_K_R4: - case GGML_TYPE_IQ2_K: - case GGML_TYPE_IQ2_K_R4: - case GGML_TYPE_IQ3_K: - case GGML_TYPE_IQ3_K_R4: - case GGML_TYPE_IQ4_KSS: - case GGML_TYPE_IQ4_KS: - case GGML_TYPE_IQ4_KS_R4: - case GGML_TYPE_IQ4_XS_R8: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_IQ4_K: - case GGML_TYPE_IQ4_K_R4: - case GGML_TYPE_Q4_K_R4: - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_IQ5_K: - case GGML_TYPE_IQ5_K_R4: - case GGML_TYPE_Q5_K_R4: - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; - case GGML_TYPE_IQ6_K: - case GGML_TYPE_Q6_K_R4: - case GGML_TYPE_Q8_K_R8: - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + + auto working_type = change_type_if_necessary(new_type, tensor->ne[0], tensor->ne[1]); + if (working_type != new_type) { ++qs.n_fallback; + new_type = working_type; + } + + if (name == "token_embd.weight") { + auto working_type = interleaved_properties(new_type).first; + if (working_type != new_type) { + printf("\n============ Token embeddings cannot be quantized with row-interleaved quants\n"); + printf("---> Changed %s to %s\n", ggml_type_name(new_type), ggml_type_name(working_type)); + new_type = working_type; + } } return new_type; @@ -16598,17 +19120,55 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa return new_size; } +static llama_ftype repacked_ftype(llama_ftype ftype) { + static std::unordered_map k_map = { + { LLAMA_FTYPE_MOSTLY_Q4_0, LLAMA_FTYPE_MOSTLY_Q4_0_R8 }, + { LLAMA_FTYPE_MOSTLY_Q8_0, LLAMA_FTYPE_MOSTLY_Q8_0_R8 }, + { LLAMA_FTYPE_MOSTLY_Q5_0, LLAMA_FTYPE_MOSTLY_Q5_0_R4 }, + { LLAMA_FTYPE_MOSTLY_Q2_K, LLAMA_FTYPE_MOSTLY_Q2_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_S, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_M, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_L, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q4_K_S, LLAMA_FTYPE_MOSTLY_Q4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q4_K_M, LLAMA_FTYPE_MOSTLY_Q4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q5_K_S, LLAMA_FTYPE_MOSTLY_Q5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q5_K_M, LLAMA_FTYPE_MOSTLY_Q5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q6_K, LLAMA_FTYPE_MOSTLY_Q6_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_XXS, LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_XS, LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_XXS, LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ1_S, LLAMA_FTYPE_MOSTLY_IQ1_S_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_NL, LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_S, LLAMA_FTYPE_MOSTLY_IQ3_S_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_M, LLAMA_FTYPE_MOSTLY_IQ2_M_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_XS, LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 }, + { LLAMA_FTYPE_MOSTLY_IQ1_M, LLAMA_FTYPE_MOSTLY_IQ1_M_R4 }, + { LLAMA_FTYPE_MOSTLY_Q6_0, LLAMA_FTYPE_MOSTLY_Q6_0_R4 }, + { LLAMA_FTYPE_MOSTLY_BF16, LLAMA_FTYPE_MOSTLY_BF16_R16 }, + { LLAMA_FTYPE_MOSTLY_IQ2_BN, LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_K, LLAMA_FTYPE_MOSTLY_IQ2_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_K, LLAMA_FTYPE_MOSTLY_IQ3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_K, LLAMA_FTYPE_MOSTLY_IQ4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ5_K, LLAMA_FTYPE_MOSTLY_IQ5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_KS, LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 }, + { LLAMA_FTYPE_MOSTLY_Q8_KV, LLAMA_FTYPE_MOSTLY_Q8_KV_R8 }, + }; + if (auto it = k_map.find(ftype); it != k_map.end()) return it->second; + return ftype; +} + static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; - switch (params->ftype) { + switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV:default_type = GGML_TYPE_Q8_KV;break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break; @@ -16632,6 +19192,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; @@ -16702,7 +19263,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model; @@ -16715,7 +19276,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s ftype = model.ftype; } const std::unordered_map> * imatrix_data = nullptr; - if (params->imatrix) { + if (!params->only_repack && params->imatrix) { imatrix_data = static_cast>*>(params->imatrix); if (imatrix_data) { LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); @@ -16737,7 +19298,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // copy the KV pairs from the input file gguf_set_kv (ctx_out, ml.meta); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV - gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV // Remove split metadata gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); @@ -16762,11 +19322,41 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } + bool is_repacked = ml.ftype >= LLAMA_FTYPE_MOSTLY_Q4_0_R8 && ml.ftype <= LLAMA_FTYPE_MOSTLY_Q8_K_R8; + int n_to_repack = 0, n_to_modify = 0; + const std::vector * repack_pattern = nullptr; + if (params->repack_pattern) repack_pattern = (const std::vector *)params->repack_pattern; + for (int i = 0; i < ml.n_tensors; ++i) { const struct ggml_tensor * meta = ml.get_tensor_meta(i); const std::string name = ggml_get_name(meta); + if (params->only_repack) { + auto repacked_type = (ggml_type)iqk_repacked_type(meta); + bool repack = false, modify = false; + if (repacked_type != meta->type) { + repack = true; + } else if (!is_repacked) { + if (iqk_should_modify_tensor(meta)) { + modify = true; + } + } + if ((repack || modify) && repack_pattern) { + bool found = false; + for (auto& r : *repack_pattern) { + std::regex pattern(r); + if (std::regex_search(name, pattern)) { + found = true; + break; + } + } + if (!found) repack = modify = false; + } + if (repack) ++n_to_repack; + else if (modify) ++n_to_modify; + } + // TODO: avoid hardcoded tensor names - use the TN_* constants if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { @@ -16776,6 +19366,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } + if (params->only_repack) { + if (n_to_repack == 0 && n_to_modify == 0) { + printf("=========================== %s: nothing to do for only_repack option\n", __func__); + return; + } + ftype = repacked_ftype(model.ftype); + printf("===================== Model ftype: %s: Repacked ftype: %s\n", llama_model_ftype_name(model.ftype).c_str(), + llama_model_ftype_name(ftype).c_str()); + } + + gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks @@ -16783,8 +19385,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // - qs.n_attention_wv == 0 for Mamba models // - qs.n_attention_wv == model.hparams.n_layer for Transformer models // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models + // - model.arch == LLM_ARCH_DECI for Deci-Nemotron models // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected"); size_t total_size_org = 0; size_t total_size_new = 0; @@ -16916,14 +19519,62 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s void * new_data; size_t new_size; - if (quantize) { - new_type = default_type; - if (new_type == GGML_TYPE_BF16_R16 && strcmp(tensor->name, "token_embd.weight") == 0) { - new_type = GGML_TYPE_BF16; + if (params->only_repack) { + ggml_type repacked_type = (ggml_type)iqk_repacked_type(tensor); + bool modify = !is_repacked && iqk_should_modify_tensor(tensor); + if ((modify || repacked_type != tensor->type) && repack_pattern) { + bool found = false; + for (auto& r : *repack_pattern) { + std::regex pattern(r); + if (std::regex_search(tensor->name, pattern)) { + found = true; break; + } + } + if (!found) { + modify = false; + repacked_type = tensor->type; + } } + if (modify || repacked_type != tensor->type) { + new_type = repacked_type; + new_size = ggml_nbytes(tensor); + if ((int)work.size() < new_size) work.resize(new_size); + new_data = work.data(); + + auto aux_tensor = *tensor; + aux_tensor.data = work.data(); + std::memcpy(aux_tensor.data, tensor->data, new_size); + + if (repacked_type != tensor->type) { + iqk_repack_tensor(&aux_tensor); + GGML_ASSERT(aux_tensor.type == repacked_type); + } else { + bool did_modify = iqk_modify_tensor(&aux_tensor); + GGML_ASSERT(did_modify); + } + } + else { + new_type = tensor->type; + new_size = ggml_nbytes(tensor); + new_data = tensor->data; + } + LLAMA_LOG_INFO("size = %8.3f MB, type = %s\n", new_size/1024.0/1024.0, ggml_type_name(new_type)); + goto QuantizationDone; + } + + if (quantize) { + + new_type = default_type; // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { + if (params->pure) { + auto working_type = change_type_if_necessary(new_type, tensor->ne[0], tensor->ne[1]); + if (working_type != new_type) { + ++qs.n_fallback; + new_type = working_type; + } + } + else if (ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { @@ -16957,6 +19608,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = params->ffn_up_type; } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + // token embeddings cannot be quantized with row-interleaved quants + auto working_type = interleaved_properties(new_type).first; + if (working_type != new_type) { + printf("\n============ Token embeddings cannot be quantized with row-interleaved quants\n"); + printf("---> Changed %s to %s\n", ggml_type_name(new_type), ggml_type_name(working_type)); + new_type = working_type; + } + } + // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; @@ -16973,6 +19634,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const float * imatrix = nullptr; if (imatrix_data) { auto it = imatrix_data->find(tensor->name); + if (it == imatrix_data->end()) { + // MLA hack: most imatrix files floating around the Internet have been computed with standard attention. + // This means that the imatrix file does not contain data for the *.attn_k_b.weight and *.attn_v_b.weight + // required by MLA. But the *.attn_v_b.weight tensors "see" the exact same activations as the + // *.attn_kv_b.weight tensors used in standard attention. Hence, if we find imatrix data for + // *.attn_kv_b.weight we can use it for *.attn_v_b.weight and vice versa. + std::string name{tensor->name}; + static std::array alternatives{".attn_v_b.weight", ".attn_kv_b.weight"}; + for (int j = 0; j < int(alternatives.size()); ++j) { + if (auto pos = name.find(alternatives[j]); pos != std::string::npos) { + int j1 = (j + 1) % alternatives.size(); + auto alternative_name = name.substr(0, pos) + alternatives[j1]; + it = imatrix_data->find(alternative_name); + break; + } + } + } if (it == imatrix_data->end()) { LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { @@ -17004,7 +19682,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type == GGML_TYPE_IQ1_S_R4|| new_type == GGML_TYPE_IQ1_M_R4|| (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { + (new_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); @@ -17024,115 +19702,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } int chunk_size_multiplier = 1; - if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { - if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0; - else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; - if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; - else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_NL_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_XS_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q4_0_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q4_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q5_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q6_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q8_0_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q2_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q3_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q3_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q4_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q5_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q6_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q8_K_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_IQ2_BN_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ5_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ5_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_KS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_KS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_XXS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XXS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_XS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_XXS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_XXS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ1_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ1_M_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_M; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_BF16_R16) { - if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; - else chunk_size_multiplier = 16; + auto [working_type, num_rows] = interleaved_properties(new_type); + if (tensor->ne[1] % num_rows != 0) { + new_type = working_type; + } else { + chunk_size_multiplier = num_rows; } LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); @@ -17165,6 +19739,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); } + +QuantizationDone:; total_size_org += ggml_nbytes(tensor); total_size_new += new_size; @@ -17413,11 +19989,13 @@ struct llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.tensor_buft_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.repack_tensors =*/ false, + /*.use_thp =*/ false, }; #ifdef GGML_USE_METAL @@ -17456,6 +20034,11 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ 0, + /*.attn_max_batch =*/ 0, + /*.fused_moe_up_gate =*/ false, + /*.min_experts =*/ -1, + /*.thtesh_experts =*/ 0.0f, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17483,8 +20066,11 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.pure =*/ false, /*.keep_split =*/ false, /*.ignore_imatrix_rules =*/ false, + /*.only_repack =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.custom_quants =*/ nullptr, + /*.repack_pattern =*/ nullptr, }; return result; @@ -17628,10 +20214,10 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { - LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); - params.flash_attn = false; - } + //if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { + // LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); + // params.flash_attn = false; + //} if (params.type_v != GGML_TYPE_F16 && params.type_v != GGML_TYPE_BF16 && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); @@ -17654,6 +20240,12 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; + cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17716,10 +20308,21 @@ struct llama_context * llama_new_context_with_model( params.seed = time(NULL); } + if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn > 0) { + LLAMA_LOG_WARN("=====================================================================\n"); + LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n"); + LLAMA_LOG_WARN("=====================================================================\n"); + cparams.mla_attn = 0; + } + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); + LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -17912,10 +20515,44 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + if (memory_size_k + memory_size_v > 0) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } + } + + { + size_t memory_size_kv = 0; + size_t memory_size_kvt = 0; + + ggml_type kv_type = GGML_TYPE_COUNT; + ggml_type kvt_type = GGML_TYPE_COUNT; + + for (auto & kv : ctx->kv_self.kv_l) { + memory_size_kv += ggml_nbytes(kv); + kv_type = kv->type; + } + + for (auto & kvt : ctx->kv_self.kvt_l) { + memory_size_kvt += ggml_nbytes(kvt); + kvt_type = kvt->type; + } + + if (memory_size_kv + memory_size_kvt > 0) { + if (cparams.mla_attn == 1 && !cparams.flash_attn) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), + ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); + } else { + GGML_ASSERT(memory_size_kvt == 0); + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f)); + } + } } // graph outputs buffer @@ -18050,6 +20687,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_DECI: + case LLM_ARCH_LLAMA4: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: @@ -18062,8 +20701,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_COHERE2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -18074,13 +20715,18 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: + case LLM_ARCH_BITNET_25: + case LLM_ARCH_BITNET_B158: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3MOE: case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: @@ -18115,6 +20761,10 @@ int32_t llama_n_layer(const struct llama_model * model) { return model->hparams.n_layer; } +int32_t llama_n_head(const struct llama_model * model) { + return model->hparams.n_head(); +} + float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -19874,6 +22524,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -20255,6 +22907,34 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BITNET) { + // bitnet-25 + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: "; + ss << message->content; + } else if (role == "user") { + ss << "User: "; + if (!system_prompt.empty()) { + ss << system_prompt; + system_prompt = ""; + } + ss << message->content << "<|eot_id|>Assistant: "; + } else { + ss << message->content; + } + } } else { // template not supported return -1;