Files
ik_llama.cpp/ggml/src/ggml-cuda/dequantize.cuh
Kawrakow 633e0617b0 Enable CUDA graphs for MoE models + GPT-OSS support (#689)
* gpt-oss: common

* gpt-oss: attention sinks, swiglu_oai

* gpt-oss: WIP llama

Model loads and runs (CPU only), but PPL is much too high
(~1500 for 1st batch vs ~200 in mainline).
Is it because of SWA, because of vocab, or did I introduce a bug somewhere?

* gpt-oss: CPU seems to be working

It was the SWA that was missing in the previous commit.

There are issues with EOG tokens, so this still needs to be added.

* CUDA: ADD_ID

Just a copy from mainline

* gpt-oss: Seems to be working on CUDA

* gpt-oss: add sinks to the attn-vec kernels

* CUDA: add head size of 64 to new mma

Haven't turned it on yet, but observe slightly better PP and slightly
worse TG performance with that.

* gpt-oss: add ability to use -fmoe (only CUDA for now)

* Move row sums to the right place

* Add sinks to iqk flash attention

* gpt_oss: Implement -fmoe on the CPU

* Simdify swiglu_oai

Turning it off for now as performance becomes more variable,
so perhaps I'm running into thermal throttling more often
because of making the CPU work too hard.

* llama: factor out model loader

* Builds successfully

* It runs, but mmap does not work

* Fix llama_mmap so mmap works

* Minor

* Fix CUDA after latest changes

* Attempt to use CUDA graphs with MoE models - not working

* CUDA graphs WIP - still not working

* CUDA graphs - seems to be working

Likely not all MLA variants are working.
I no longer remember why I added the q8_0 cpy that
transposes the tensor, but if really needed, this is now
missing. Also missing is q6_0.

* Make q8_0 cache work for DeepSeek models with CUDA graphs

* cuda: cpy for q6_0

* Fix llama_mmap on non-Linux platforms

* Adding forgotten file

* Iterating on Windows build failures

* cuda: re-add q8_0 -> q8_0 transpose

so mla = 2 can be used with CUDA graphs and q8_0 cache.

* Disable graphs without -fmoe

* Minor

* Turn graphs on by default

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-08-15 09:18:07 +03:00

122 lines
3.1 KiB
Plaintext

#include "common.cuh"
// Dequantizes the pair of values packed into one qs[] entry of a q4_0 block.
//   vx  - base pointer, reinterpreted as an array of block_q4_0
//   ib  - index of the block inside that array
//   iqs - index into the block's qs[] array
//   v   - output pair: v.x comes from the low 4 bits of qs[iqs],
//         v.y from qs[iqs] >> 4
// Each output is (q - 8) multiplied by the block's d field.
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q4_0 & blk = ((const block_q4_0 *) vx)[ib];

    const dfloat scale  = blk.d;
    const int    packed = blk.qs[iqs];

    // Split the packed entry into its two quants.
    v.x = packed & 0xF;
    v.y = packed >> 4;

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: handle both lanes with the paired intrinsics.
    v = __hsub2(v, {8.0f, 8.0f});
    v = __hmul2(v, {scale, scale});
#else
    v.x = (v.x - 8.0f) * scale;
    v.y = (v.y - 8.0f) * scale;
#endif // GGML_CUDA_F16
}
// Dequantizes the pair of values packed into one qs[] entry of a q4_1 block.
//   vx  - base pointer, reinterpreted as an array of block_q4_1
//   ib  - index of the block inside that array
//   iqs - index into the block's qs[] array
//   v   - output pair: v.x comes from the low 4 bits of qs[iqs],
//         v.y from qs[iqs] >> 4
// Each output is q * d + m, where d = __low2half(dm) and m = __high2half(dm).
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q4_1 & blk = ((const block_q4_1 *) vx)[ib];

    // dm carries two values: the multiplier in its low half, the addend in its high half.
    const dfloat scale  = __low2half(blk.dm);
    const dfloat offset = __high2half(blk.dm);

    const int packed = blk.qs[iqs];

    // Split the packed entry into its two quants.
    v.x = packed & 0xF;
    v.y = packed >> 4;

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: handle both lanes with the paired intrinsics.
    v = __hmul2(v, {scale, scale});
    v = __hadd2(v, {offset, offset});
#else
    v.x = (v.x * scale) + offset;
    v.y = (v.y * scale) + offset;
#endif // GGML_CUDA_F16
}
// Dequantizes one pair of values from a q5_0 block.
//   vx  - base pointer, reinterpreted as an array of block_q5_0
//   ib  - index of the block inside that array
//   iqs - index into the block's qs[] array; also selects which bits of qh are used
//   v   - output pair
// Each quant is 4 bits taken from qs[iqs] plus one extra bit (placed at bit 4)
// taken from the 32-bit qh word. Each output is (q - 16) multiplied by the
// block's d field.
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q5_0 & blk = ((const block_q5_0 *) vx)[ib];

    const dfloat scale = blk.d;

    // Copy the first 4 bytes of qh into a 32-bit word via memcpy.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Bit iqs of the word becomes bit 4 of the first quant;
    // bit (iqs + 16) becomes bit 4 of the second quant.
    const int fifth_x = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int fifth_y = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int packed = blk.qs[iqs];

    v.x = (packed & 0xf) | fifth_x;
    v.y = (packed >>  4) | fifth_y;

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: handle both lanes with the paired intrinsics.
    v = __hsub2(v, {16.0f, 16.0f});
    v = __hmul2(v, {scale, scale});
#else
    v.x = (v.x - 16.0f) * scale;
    v.y = (v.y - 16.0f) * scale;
#endif // GGML_CUDA_F16
}
// Dequantizes one pair of values from a q5_1 block.
//   vx  - base pointer, reinterpreted as an array of block_q5_1
//   ib  - index of the block inside that array
//   iqs - index into the block's qs[] array; also selects which bits of qh are used
//   v   - output pair
// Each quant is 4 bits taken from qs[iqs] plus one extra bit (placed at bit 4)
// taken from the 32-bit qh word. Each output is q * d + m, where
// d = __low2half(dm) and m = __high2half(dm).
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q5_1 & blk = ((const block_q5_1 *) vx)[ib];

    // dm carries two values: the multiplier in its low half, the addend in its high half.
    const dfloat scale  = __low2half(blk.dm);
    const dfloat offset = __high2half(blk.dm);

    // Copy the first 4 bytes of qh into a 32-bit word via memcpy.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Bit iqs of the word becomes bit 4 of the first quant;
    // bit (iqs + 16) becomes bit 4 of the second quant.
    const int fifth_x = ((high_bits >> (iqs +  0)) << 4) & 0x10;
    const int fifth_y = ((high_bits >> (iqs + 12))     ) & 0x10;

    const int packed = blk.qs[iqs];

    v.x = (packed & 0xf) | fifth_x;
    v.y = (packed >>  4) | fifth_y;

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: handle both lanes with the paired intrinsics.
    v = __hmul2(v, {scale, scale});
    v = __hadd2(v, {offset, offset});
#else
    v.x = (v.x * scale) + offset;
    v.y = (v.y * scale) + offset;
#endif // GGML_CUDA_F16
}
// Dequantizes one pair of values from a q6_0 block.
//   vx  - base pointer, reinterpreted as an array of block_q6_0
//   ib  - index of the block inside that array
//   iqs - index into the block's qs[] array; also selects the qh entry and shift
//   v   - output pair
// Each quant is 4 bits taken from qs[iqs] plus two extra bits (placed at
// bits 4-5) taken from qh. Each output is (q - 32) multiplied by the block's
// d field.
static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q6_0 & blk = ((const block_q6_0 *) vx)[ib];

    const dfloat scale = blk.d;

    // qh entry (iqs % 8), shifted right by 2 * (iqs / 8) and truncated to 8 bits.
    // After the shift, bits 0-1 belong to the first quant and bits 2-3 to the second.
    const uint8_t upper = blk.qh[iqs % 8] >> (2 * (iqs / 8));

    const int packed = blk.qs[iqs];

    v.x = (packed & 0xf) | ((upper & 0x3) << 4);
    v.y = (packed >>  4) | ((upper & 0xc) << 2);

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: handle both lanes with the paired intrinsics.
    v = __hsub2(v, {32.0f, 32.0f});
    v = __hmul2(v, {scale, scale});
#else
    v.x = (v.x - 32.0f) * scale;
    v.y = (v.y - 32.0f) * scale;
#endif // GGML_CUDA_F16
}
// Dequantizes two consecutive values from a q8_0 block.
//   vx  - base pointer, reinterpreted as an array of block_q8_0
//   ib  - index of the block inside that array
//   iqs - index of the first of the two qs[] entries to read
//   v   - output pair: v.x = qs[iqs] * d, v.y = qs[iqs + 1] * d
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const block_q8_0 & blk = ((const block_q8_0 *) vx)[ib];

    const dfloat scale = blk.d;

    // Unlike the packed formats, each quant occupies its own qs[] entry.
    v.x = blk.qs[iqs + 0];
    v.y = blk.qs[iqs + 1];

#ifdef GGML_CUDA_F16
    // GGML_CUDA_F16 build: scale both lanes with the paired intrinsic.
    v = __hmul2(v, {scale, scale});
#else
    v.x *= scale;
    v.y *= scale;
#endif // GGML_CUDA_F16
}