diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 2caea5a0..d6f4cf3a 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2703,6 +2703,21 @@ template static __device__ __forceinlin } } +/* iq3k_table: 128 packed uint16_t pairs declared with __device__. Entry i of each 64-entry half holds V[i & 7] in its low byte and V[i >> 3] in its high byte; the first half uses V = {0xc1, 0xd8, 0xe9, 0xf6, 0x01, 0x0d, 0x1c, 0x2f} and the second half uses each of those bytes plus 4. NOTE(review): presumably these bytes are the int8 entries of iq3nl_values that the removed code indexed one at a time - confirm against its definition. */static const __device__ uint16_t iq3k_table[128] = { + 0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f, + 0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f, + 0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f, + 0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f, + 0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33, + 0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33, + 0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133, + 0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333, +}; + +/* int_from_table_2: builds one 32-bit word from four indices a8[0..3]. values[a8[0] | (a8[1] << 3)] supplies bits 0-15 and values[a8[2] | (a8[3] << 3)] supplies bits 16-31. The callers visible in this patch mask every a8 byte to 3 bits and pass iq3k_table offset by 0 or 0x40, so each lookup stays inside one 64-entry half of that table. */__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) { + return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16); +} + template static __device__ __forceinline__ void load_tiles_iq3_ks( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2718,8 +2733,6 @@ template static __device__ __forceinlin constexpr int qstep = 8; const int kqsx = threadIdx.x % qstep; - auto values = iq3nl_values; - uint32_t aux32[4]; const uint8_t * aux8 = (const uint8_t *)aux32; #pragma unroll @@ -2741,28 +2754,29 @@ template static __device__ __forceinlin 
for (int l = 0; l < qstep/4; ++l) { const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); - aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101); - aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 2) & 8) * 0x01010101); - aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra << 1) & 8) * 0x01010101); - aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra << 0) & 8) * 0x01010101); + aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404); + aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404); + aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404); + aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404); + + const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40)); + const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 5) & 0x40)); + const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 4) & 0x40)); + const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 3) & 0x40)); + extra >>= 4; qh >>= 4; - const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]); - const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]); - const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]); - const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]); - #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = val1; + 
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = val3; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2; - x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3; #endif // INT8_MMA_AVAILABLE }