This commit is contained in:
Iwan Kawrakow
2025-05-14 11:53:11 +03:00
parent 7ec38dab8a
commit 775b9091cb

View File

@@ -2503,8 +2503,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;
auto values = iq5nl_values;
uint32_t aux32[2];
const uint8_t * aux8 = (const uint8_t *)aux32;
#pragma unroll
@@ -2517,17 +2515,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0;
// kqsx = 0 -> 0,1,2,3 + 8,9,10,11
// kqsx = 1 -> 4,5,6,7 + 12,13,14,15
// kqsx = 2 -> 16,17,18,19 + 24,25,26,27
// kqsx = 3 -> 20,21,22,23 + 28,29,30,31
// or would it be better to use
// kqsx = 0 -> 0,1 + 8,9 + 16,17 + 24,25
// kqsx = 1 -> 2,3 + 10,11 + 18,19 + 26,27, etc.
// or perhaps even a stride-8 interleave
// kqsx = 0 -> 0, 8, 16, 24, 32, 40, 48, 56
// kqsx = 1 -> 1, 9, 17, 25, 33, 41, 49, 57, etc.
int qh = get_int_b4(bxi->qh, kqsx);
uint16_t extra = bxi->extra >> (kqsx/4);
@@ -2535,13 +2522,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
for (int l = 0; l < qstep/2; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) << 5);
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) << 3);
aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4);
aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3);
qh >>= 2;
auto values_l = iq5nl_values + ((extra & 1) << 5);
auto values_h = iq5nl_values + ((extra & 4) << 3);
extra >>= 4;
const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]);
const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]);
const char4 val0 = make_char4(values_l[aux8[0]], values_l[aux8[1]], values_l[aux8[2]], values_l[aux8[3]]);
const char4 val1 = make_char4(values_h[aux8[4]], values_h[aux8[5]], values_h[aux8[6]], values_h[aux8[7]]);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0;
@@ -2552,9 +2542,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
#endif // INT8_MMA_AVAILABLE
}
// iq4_k: scales_h[ib/8] |= (l_h << 2*(ib%8)); ib: 0...15
// iq5_k: scales_h[ib/4] |= (l_h << 2*(ib%4)); ib: 0...15
const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2);
const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32;
const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32;