mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
iq3_ks: faster mmq
This commit is contained in:
@@ -2703,6 +2703,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
}
|
||||
|
||||
static const __device__ uint16_t iq3k_table[128] = {
|
||||
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
|
||||
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
|
||||
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
|
||||
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
|
||||
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
|
||||
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
|
||||
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
|
||||
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
|
||||
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_ks(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
@@ -2718,8 +2733,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
constexpr int qstep = 8;
|
||||
const int kqsx = threadIdx.x % qstep;
|
||||
|
||||
auto values = iq3nl_values;
|
||||
|
||||
uint32_t aux32[4];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
#pragma unroll
|
||||
@@ -2741,28 +2754,29 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
for (int l = 0; l < qstep/4; ++l) {
|
||||
|
||||
const int ql = get_int_b2(bxi->qs, kqsx + qstep*l);
|
||||
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101);
|
||||
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 2) & 8) * 0x01010101);
|
||||
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra << 1) & 8) * 0x01010101);
|
||||
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra << 0) & 8) * 0x01010101);
|
||||
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
|
||||
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
|
||||
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
|
||||
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
|
||||
|
||||
const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40));
|
||||
const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 5) & 0x40));
|
||||
const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 4) & 0x40));
|
||||
const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 3) & 0x40));
|
||||
|
||||
extra >>= 4;
|
||||
qh >>= 4;
|
||||
|
||||
const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
|
||||
const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
|
||||
const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
|
||||
const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = *(const int *)&val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = *(const int *)&val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = *(const int *)&val3;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = val3;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user