Commit Graph

3368 Commits

Author SHA1 Message Date
Iwan Kawrakow
0d19d19af8 iq3_k: CUDA dot product
Slightly slower than iq3_s - 132 t/s vs 138 t/s for
LLaMA-3.1-8B.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
4f237d44f6 iq3_k: Basics
Quantize/dequantize, CUDA dequantize.
PPL of LLaMA-3.1-8B is better than iq3_s and iq3_m.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
36204c4ec7 iq2_k: very slightly better CUDA dot product
169.2 t/s vs 167.8 t/s before.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
e950b17125 iq2_k: better CUDA dot product
Almost on par with iq2_xs (168 t/s vs 172 t/s).
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
ab4f9e1fdb iq2_k: CUDA dot product finally works
Performance is pathetic: 140 t/s for LLaMA-3.1-8B vs
172 t/s for iq2_xs.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
69842c6ad8 iq5_k: CUDA dot product finally works 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
f6813cac0e Factor out iqk CUDA dot products
I cannot possibly wait for a 5-minute nvcc compilation
each time I touch vecdotq.cuh.

Also, cmake was adding --options-file X.rsp to the nvcc
compile commands, which confuses clangd, so I have turned
that off.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
22d1568c1c iq5_k: CUDA dot product still not working 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
d8d022a01b iq5_k: Metal
Performance is roughly on par with q5_0.
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
bd36ade98d iq5_k: NEON 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
c0d0607f19 iq5_k: AVX512 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
c56ddee38c iq5_k: AVX2 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
5d341757bc iq5_k: Basics
Quantize/dequantize, CUDA dequantize
2024-08-01 09:38:06 +02:00
Iwan Kawrakow
06e255ac9d iq2_k: Metal. Dot product is wrong 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
f476ea3b50 iq2_k: NEON 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
c0fe03b5c8 iq2_k: slightly faster AVX512 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
7d08719975 iq2_k: simplify AVX512 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
13091d39e8 iq2_k: AVX2 2024-08-01 09:38:06 +02:00
Iwan Kawrakow
c85e139c68 iq2_k: Basics
Quantize/dequantize, CUDA dequantize, AVX512 iqk_mul_mat.
2024-08-01 09:38:06 +02:00
Kawrakow
291066e6df IQ4_K: SOTA 4-bit quantization (#6)
* iq4_k: basics

* quantize/dequantize works
* CUDA dequantize works and one can run PPL calcs. I get
  PPL = 6.5258 for LLaMA-3.1-8B, which is 1.77% above fp16.
  In comparison, q4_K_S (same size) is 2.88% above fp16.
* TG on CUDA does not work. Johannes has changed the way i-quant dot
  products are done, so I need to sort out what he had in mind.
* iqk_mul_mat is not implemented.

* iq4_k: TG now works on CUDA

* iq4_k: AVX512 implementation

For LLaMA-3.1-8B we get PP-512 = 182.6 t/s, TG-128 = 13.6 t/s,
so almost the same as q4_K_S.

* iq4_k: AVX2 implementation

For LLaMA-3.1-8B we get PP-512 = 203.1 t/s, TG-128 = 12.9 t/s
on the Ryzen-5975X.

* iq4_k: NEON implementation

For LLaMA-3.1-8B we get PP-512 = 60.7 t/s, TG-128 = 25.0 t/s
on the M2-Max. TG is on par with q4_K_S, PP is ~10% slower.

* iq4_k: Metal implementation

For LLaMA-3.1-8B we get PP-512 = 445 t/s, TG-128 = 46.3 t/s
on a 30-core M2-Max GPU. This is to be compared with (currently)
PP-512 = 460 t/s, TG-128 = 51 t/s for q4_K_S.

* iq4_k: scalar dot product

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-28 12:11:59 +02:00
Kawrakow
f62615b44f Simdify and multi-thread tanh (#4)
It seemed Gemma-2 performance was lower than expected for its size.
Looking at the architecture, I noticed that tanh is used in each layer,
and then at the end for softcapping the final output. ggml had tanh
set to be computed with a single thread. Combined with tanh(x) being a
pretty expensive operation, this resulted in a significant fraction
of the time being spent in the tanh operation.

After multi-threading ggml_vec_soft_max_f32 and simd-ifying the
tanh computation, I observe a 33% gain in prompt processing speed (!!!)
TG is of course memory bound, but despite this, we still get a
~2% boost at 4 threads (which gives max TG performance on my
Ryzen-7950X).

Simd-ifying:
We have
   tanh(x) = (exp(2*x) - 1)/(exp(2*x) + 1)
so we can just use Justine Tunney's SIMD exp implementation.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-27 08:44:18 +02:00
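
A minimal C++ sketch of the two ideas described above, assuming a simple even split of the vector across threads; the names and the chunking are illustrative, not the actual ggml implementation:

    #include <cmath>
    #include <cstdint>

    // Each of nth threads handles its own contiguous chunk, and tanh(x) is
    // expressed through exp(2*x) so a SIMD exp kernel can replace std::exp.
    static void vec_tanh_chunk(float * y, const float * x, int64_t i0, int64_t i1) {
        for (int64_t i = i0; i < i1; ++i) {
            const float e = std::exp(2.0f * x[i]);   // SIMD exp would go here
            y[i] = (e - 1.0f) / (e + 1.0f);
        }
    }

    static void vec_tanh_mt(float * y, const float * x, int64_t n, int ith, int nth) {
        const int64_t chunk = (n + nth - 1) / nth;              // work per thread
        const int64_t i0    = ith * chunk;
        const int64_t i1    = i0 + chunk < n ? i0 + chunk : n;
        if (i0 < n) vec_tanh_chunk(y, x, i0, i1);
    }
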
Kawrakow
154e0d75fc Merge mainline llama.cpp (#3)
* Merging mainline - WIP

* Merging mainline - WIP

AVX2 and CUDA appear to work.
CUDA performance seems slightly (~1-2%) lower, as is so often
the case with llama.cpp/ggml after some "improvements" have been made.

* Merging mainline - fix Metal

* Remove check

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-27 07:55:01 +02:00
Kawrakow
0684c3e9c7 Offload Bitnet token embeddings to the GPU - the right way (#2)
OK, I should have checked how it was done for Gemma and done
the same for Bitnet. But better late than never.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-26 12:57:23 +02:00
Kawrakow
94b5916319 Offload Bitnet token embeddings to the GPU (#1)
* bitnet: put token embeddings on the GPU

* Update README with the new CUDA/Metal performance

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-26 09:41:04 +02:00
Iwan Kawrakow
c2158c15d9 iqk_mul_mat(NEON): adding forgotten fp16 matrix x vector implementation 2024-07-25 08:37:13 +02:00
Kawrakow
28fb349db4 Update README.md 2024-07-24 19:55:06 +02:00
Kawrakow
eb246cd0ae Update README.md
Trying to avoid line breaks in table
2024-07-24 19:44:52 +02:00
Kawrakow
fc07ca7847 Update README.md 2024-07-24 19:20:46 +02:00
Iwan Kawrakow
770f3585c2 Add copyright notices
Only on the files where I have contributed in a significant way,
or the files I wrote myself.
2024-07-24 20:11:42 +03:00
Iwan Kawrakow
9eee03f4ee Remove unused file 2024-07-24 19:33:19 +03:00
Iwan Kawrakow
3d83f58654 Remove security 2024-07-24 19:25:21 +03:00
Iwan Kawrakow
b64275ca4e Correct spelling in README 2024-07-24 19:22:43 +03:00
Kawrakow
4192244242 Update README.md
Adding some more details
2024-07-24 17:38:37 +02:00
Kawrakow
47c1243e3c Update README.md
Adding MoE and Bitnet performance tables
2024-07-24 16:49:00 +02:00
Kawrakow
8fe7e04456 Update README.md
I hate it when tables look fine in the Preview but then end up with columns split into 2 lines when committed. That's what is happening here, so I removed the test column from the performance tables.
2024-07-24 11:18:50 +02:00
Kawrakow
a5c39e9476 Update README.md
Added performance comparison tables
2024-07-24 11:01:16 +02:00
Iwan Kawrakow
6b4167164c iqk_mul_mat(NEON): special case for n not divisible by 8
Else fp16 PP performance drops by nearly a factor of 2 compared to
what we had before.
2024-07-24 08:04:47 +02:00
Iwan Kawrakow
2e49f0172f ggml: thread synchronization on Arm
For x86 slaren was generous enough to add _mm_pause() to the busy
spin wait loop in ggml_barrier(), but everything else just busy
spins, loading an atomic int on every iteration, thus forcing cache
sync between the cores. This results in a massive drop in performance
on my M2-Max laptop when using 8 threads. The closest approximation
to _mm_pause() on Arm seems to be
     __asm__ __volatile__("isb\n");
After adding this to the busy spin loop, performance for 8 threads
recovers back to expected levels.
2024-07-24 08:04:47 +02:00
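
A minimal C++ sketch of the busy wait described above, with hypothetical names (cpu_relax, spin_until); the real logic lives in ggml_barrier:

    #include <atomic>
    #if defined(__x86_64__) || defined(_M_X64)
    #include <immintrin.h>
    #endif

    // Hint the core on every spin iteration, so the waiting threads do not
    // constantly force cache synchronization on the shared atomic value.
    static inline void cpu_relax() {
    #if defined(__x86_64__) || defined(_M_X64)
        _mm_pause();                          // x86 pause
    #elif defined(__aarch64__)
        __asm__ __volatile__("isb\n");        // closest Arm approximation
    #endif
    }

    static void spin_until(const std::atomic<int> & value, int target) {
        while (value.load(std::memory_order_relaxed) != target) {
            cpu_relax();
        }
    }
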
Iwan Kawrakow
abb740c9a4 Fix "make it work for row sizes that are multiple of 4 on NEON" 2024-07-24 08:04:47 +02:00
Kawrakow
0117e386b3 Update README.md 2024-07-23 18:05:05 +02:00
Kawrakow
11e2472c64 Update README.md 2024-07-23 12:23:06 +02:00
Iwan Kawrakow
99119ec29c When tokenizer info is missing in the model, use llama3 by default 2024-07-19 12:29:01 +03:00
Iwan Kawrakow
30b8bcf1a3 iqk_mul_mat(f16): make it work for row sizes that are multiple of 4 on NEON
Here the performance gain is more modest compared to AVX2: we get
PP-512 = 200 t/s up from 190 t/s for iq1_bn-quantized Bitnet-3B
running on M2 Max.
2024-07-18 13:55:51 +02:00
Iwan Kawrakow
8db01c0804 iqk_mul_mat: attentions matrix multiplications
K*Q and KQ*V are n_kv_embed x n_token x n_head matrix multiplications.
Before this PR, this meant n_head calls to iqk_mul_mat to perform
n_kv_embed x n_token 2D multiplications, each using nth threads.
Instead, in this PR, if n_head is a multiple of nth, each thread
does n_head/nth multiplications of the n_kv_embed x n_token 2D matrices.
This improves PP-512(32 threads) for Bitnet-3B to 433 t/s up from
409 t/s. It is beneficial in other cases too. E.g., for LLaMA-7B,
we go to 201 t/s up from 193 t/s for q4_K_S, and to 144 t/s up from
139 t/s for fp16. All these numbers are for the Ryzen-7950X CPU.
2024-07-18 14:00:56 +03:00
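
A minimal C++ sketch of this work split, with illustrative names rather than the actual iqk_mul_mat interface; it assumes n_head is a multiple of nth, as in the description above:

    // Each thread performs whole n_kv_embed x n_token multiplications on its
    // own, instead of all nth threads cooperating on each of the n_head heads.
    template <typename MulMat2D>   // callable as mul_mat_2d(head_index)
    static void mul_mat_per_head(int n_head, int ith, int nth, MulMat2D mul_mat_2d) {
        const int per_thread = n_head / nth;     // n_head % nth == 0 assumed
        for (int h = ith * per_thread; h < (ith + 1) * per_thread; ++h) {
            mul_mat_2d(h);                       // full 2D matmul for head h
        }
    }
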
Iwan Kawrakow
744eb9ffa9 iqk_mul_mat(float): make it work for row sizes that are multiple of 4 on AVX2
I was trying to understand where the Bitnet bottleneck is, and at
some point noticed the Q*K matrix multiplication where Q and K
have the shape of 100 x n_token x 32 x 1. The existing iqk_mul_mat for
floats requires that the row size is a multiple of the SIMD vector size
(so, 16 on the Ryzen-7950X, 8 on the Ryzen-5975), and hence this
matrix multiplication was getting done with ggml. Changing the iqk_mul_mat
float kernel to handle row sizes that are a multiple of 4 (via __m128
for the last values in a row) resulted in nearly a 20% performance boost
for PP-512 and ~3% for TG-128! If I go to a context of 2048, PP performance
increases by nearly 70%!
2024-07-18 11:39:32 +03:00
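
An illustrative C++ sketch of the tail handling described above, shown on a plain dot product rather than the actual iqk_mul_mat float kernel: the main loop runs in 8-float AVX2 steps, and a single __m128 step covers the last 4 values when the row size is only a multiple of 4:

    #include <immintrin.h>

    static float dot_f32(const float * x, const float * y, int n) {
        __m256 acc = _mm256_setzero_ps();
        int i = 0;
        for (; i + 8 <= n; i += 8) {            // main 8-wide loop
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i), acc);
        }
        __m128 acc4 = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
        if (i + 4 <= n) {                       // extra __m128 step for the last 4 values
            acc4 = _mm_fmadd_ps(_mm_loadu_ps(x + i), _mm_loadu_ps(y + i), acc4);
        }
        acc4 = _mm_add_ps(acc4, _mm_movehl_ps(acc4, acc4));   // horizontal sum
        acc4 = _mm_add_ss(acc4, _mm_movehdup_ps(acc4));
        return _mm_cvtss_f32(acc4);
    }
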
Iwan Kawrakow
6a132862fd Fix Makefile, add GGML_USE_IQK_MULMAT ifdefs to iqk-quantize 2024-07-17 16:51:34 +03:00
Iwan Kawrakow
a4017cc047 iq1bn: faster scalar dot product
At the end of the day, lookup is still better when not using SIMD.
This scalar dot product version gets us 14.7 t/s on a Ryzen-7950X
with 16 threads (up from 10.5 t/s).
2024-07-17 16:09:01 +03:00
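
A generic C++ sketch of a lookup-based scalar dot product; the packing (4 ternary values per byte) and the table are hypothetical, not the real iq1_bn layout, and only illustrate why a lookup beats bit manipulation in scalar code:

    #include <cstdint>

    // lut[b] holds the 4 already-decoded {-1, 0, +1} values for byte b, so the
    // inner loop is just table reads and multiply-adds. Assumes n % 4 == 0.
    static float dot_ternary_lut(const uint8_t * q, const float * y, int n,
                                 const int8_t lut[256][4], float scale) {
        float sum = 0.f;
        for (int i = 0; i < n/4; ++i) {
            const int8_t * v = lut[q[i]];
            for (int j = 0; j < 4; ++j) sum += v[j] * y[4*i + j];
        }
        return scale * sum;
    }
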
Iwan Kawrakow
a0df4002fc iq1bn: fix scalar dot product
The fix makes it faster on the Ryzen-7950X (10.5 t/s vs 8.2 t/s)
but slower on the M2 (6.8 t/s vs 8.6 t/s before).
2024-07-17 13:37:18 +03:00
Iwan Kawrakow
7024ecfeb4 iq1bn: faster AVX2
Instead of shuffling quant data into a 128-bit register containing
8-bit ints, and then converting to 16 bit, we directly shuffle into
a 256-bit register containing 16 bit ints.

TG-128 @ 2 threads goes from 18.3 to 21.6 t/s.
TG-128 performance now saturates already at 8 threads getting 60.4 t/s.
There is almost no impact on PP-512 (322 -> 323 t/s). I guess
we amortize dequantization cost pretty well, so we don't gain much
there.

We get close to 100 GB/s single-threaded float32 throughput:

./bin/test-quantize-perf --op vec_dot_q -i 10000000 --type iq1_bn
iq1_bn
  vec_dot_q
    4096 values (0.02 MB)
      min cycles/32 vals   :      3.87
      avg cycles/32 vals   :      4.40
      float32 throughput   :     98.27 GB/s
      quantized throughput :      4.99 GB/s
2024-07-17 10:17:05 +03:00
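
An illustrative C++ sketch of the before/after shuffle, not the actual iq1_bn kernel; the shuffle masks (shuf, shuf16) are assumed to be prepared by the caller:

    #include <immintrin.h>

    // Old path: byte shuffle into a 128-bit register, then widen to 16 bits.
    static __m256i expand_old(__m128i packed, __m128i shuf) {
        const __m128i bytes = _mm_shuffle_epi8(packed, shuf);   // 16 x int8
        return _mm256_cvtepu8_epi16(bytes);                     // 16 x int16
    }

    // New path: broadcast the source bytes to both 128-bit lanes and use a
    // single 256-bit shuffle whose mask places each byte in the low half of a
    // 16-bit slot (a mask byte with its high bit set yields the zero high half).
    static __m256i expand_new(__m128i packed, __m256i shuf16) {
        const __m256i bytes = _mm256_broadcastsi128_si256(packed);
        return _mm256_shuffle_epi8(bytes, shuf16);              // 16 x int16 directly
    }
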
Iwan Kawrakow
febb8bbea0 Remove the no longer used iq1bn_grid_u16 2024-07-17 10:16:50 +03:00