diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 039d896a..54de45c7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2401,16 +2401,16 @@ template static __device__ __forceinlin const float d = bxi->d; - uint16_t extra = bxi->extra >> 4*(kqsx/4); + uint16_t extra = bxi->extra >> (kqsx/4); #pragma unroll for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((extra & 1) * 0x04040404); - aux32[1] = ((ql >> 2) & 0x03030303) | ((extra & 2) * 0x02020202); - aux32[2] = ((ql >> 4) & 0x03030303) | ((extra & 4) * 0x01010101); - aux32[3] = ((ql >> 6) & 0x03030303) | (((extra >> 1) & 4) * 0x01010101); + aux32[0] = ((ql >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101); + aux32[1] = ((ql >> 2) & 0x03030303) | (((extra << 0) & 4) * 0x01010101); + aux32[2] = ((ql >> 4) & 0x03030303) | (((extra >> 2) & 4) * 0x01010101); + aux32[3] = ((ql >> 6) & 0x03030303) | (((extra >> 4) & 4) * 0x01010101); extra >>= 8; const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);