mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
cuda: slightly faster MMQ for iq3_k, iq3_k_r4
This commit is contained in:
@@ -73,3 +73,18 @@ __device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int *
|
||||
return values[ggml_cuda_dp4a(idx, 0x40100401, 0)];
|
||||
}
|
||||
|
||||
static const __device__ uint16_t iq3k_table[128] = {
|
||||
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
|
||||
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
|
||||
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
|
||||
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
|
||||
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
|
||||
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
|
||||
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
|
||||
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
|
||||
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
|
||||
}
|
||||
|
||||
|
||||
@@ -950,21 +950,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
|
||||
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ3_K_Q8_1_MMQ 4
|
||||
|
||||
static const __device__ uint16_t iq3k_table[128] = {
|
||||
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
|
||||
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
|
||||
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
|
||||
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
|
||||
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
|
||||
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
|
||||
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
|
||||
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
|
||||
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void vec_dot_iq3_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
|
||||
const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx;
|
||||
|
||||
@@ -2623,8 +2623,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
constexpr int qstep = 8;
|
||||
const int kqsx = threadIdx.x % qstep;
|
||||
|
||||
auto values = iq3nl_values;
|
||||
|
||||
uint32_t aux32[4];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
#pragma unroll
|
||||
@@ -2646,57 +2644,43 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
for (int l = 0; l < qstep/4; ++l) {
|
||||
|
||||
const int ql = get_int_b2(bxi->qs, kqsx + qstep*l);
|
||||
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404) | (((extra << 3) & 8) * 0x01010101);
|
||||
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404) | (((extra << 1) & 8) * 0x01010101);
|
||||
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404) | (((extra >> 1) & 8) * 0x01010101);
|
||||
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404) | (((extra >> 3) & 8) * 0x01010101);
|
||||
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
|
||||
aux32[1] = ((ql >> 2) & 0x03030303) | ((qh << 1) & 0x04040404);
|
||||
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
|
||||
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
|
||||
|
||||
const int val0 = int_from_table_2(aux8+ 0, iq3k_table + ((extra << 6) & 0x40));
|
||||
const int val1 = int_from_table_2(aux8+ 4, iq3k_table + ((extra << 4) & 0x40));
|
||||
const int val2 = int_from_table_2(aux8+ 8, iq3k_table + ((extra << 2) & 0x40));
|
||||
const int val3 = int_from_table_2(aux8+12, iq3k_table + ((extra << 0) & 0x40));
|
||||
|
||||
extra >>= 8;
|
||||
qh >>= 4;
|
||||
|
||||
const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
|
||||
const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
|
||||
const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
|
||||
const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = val3;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = val3;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
uint16_t sh = bxi->scales_h >> 2*kqsx;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1));
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1));
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1));
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1));
|
||||
#else
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*((bxi->scales_l[kqsx] >> 0) & 0xf) + 1) * (sh & 1 ? -1 : 1));
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*((bxi->scales_l[kqsx] >> 4) & 0xf) + 1) * (sh & 2 ? -1 : 1));
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1));
|
||||
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1));
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
static const __device__ uint16_t iq3k_table[128] = {
|
||||
0xc1c1, 0xc1d8, 0xc1e9, 0xc1f6, 0xc101, 0xc10d, 0xc11c, 0xc12f, 0xd8c1, 0xd8d8, 0xd8e9, 0xd8f6, 0xd801, 0xd80d, 0xd81c, 0xd82f,
|
||||
0xe9c1, 0xe9d8, 0xe9e9, 0xe9f6, 0xe901, 0xe90d, 0xe91c, 0xe92f, 0xf6c1, 0xf6d8, 0xf6e9, 0xf6f6, 0xf601, 0xf60d, 0xf61c, 0xf62f,
|
||||
0x01c1, 0x01d8, 0x01e9, 0x01f6, 0x0101, 0x010d, 0x011c, 0x012f, 0x0dc1, 0x0dd8, 0x0de9, 0x0df6, 0x0d01, 0x0d0d, 0x0d1c, 0x0d2f,
|
||||
0x1cc1, 0x1cd8, 0x1ce9, 0x1cf6, 0x1c01, 0x1c0d, 0x1c1c, 0x1c2f, 0x2fc1, 0x2fd8, 0x2fe9, 0x2ff6, 0x2f01, 0x2f0d, 0x2f1c, 0x2f2f,
|
||||
0xc5c5, 0xc5dc, 0xc5ed, 0xc5fa, 0xc505, 0xc511, 0xc520, 0xc533, 0xdcc5, 0xdcdc, 0xdced, 0xdcfa, 0xdc05, 0xdc11, 0xdc20, 0xdc33,
|
||||
0xedc5, 0xeddc, 0xeded, 0xedfa, 0xed05, 0xed11, 0xed20, 0xed33, 0xfac5, 0xfadc, 0xfaed, 0xfafa, 0xfa05, 0xfa11, 0xfa20, 0xfa33,
|
||||
0x05c5, 0x05dc, 0x05ed, 0x05fa, 0x0505, 0x0511, 0x0520, 0x0533, 0x11c5, 0x11dc, 0x11ed, 0x11fa, 0x1105, 0x1111, 0x1120, 0x1133,
|
||||
0x20c5, 0x20dc, 0x20ed, 0x20fa, 0x2005, 0x2011, 0x2020, 0x2033, 0x33c5, 0x33dc, 0x33ed, 0x33fa, 0x3305, 0x3311, 0x3320, 0x3333,
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16_t * values) {
|
||||
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_ks(
|
||||
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
|
||||
auto values_l = iq3nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 3);
|
||||
auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6);
|
||||
|
||||
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
|
||||
aux32[0] = ((ql >> 0) & 0x03030303) | ((qh << 2) & 0x04040404);
|
||||
@@ -45,21 +45,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
aux32[2] = ((ql >> 4) & 0x03030303) | ((qh >> 0) & 0x04040404);
|
||||
aux32[3] = ((ql >> 6) & 0x03030303) | ((qh >> 1) & 0x04040404);
|
||||
|
||||
const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
|
||||
const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
|
||||
const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
|
||||
const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
|
||||
int val0 = int_from_table_2(aux8+ 0, values_l);
|
||||
int val1 = int_from_table_2(aux8+ 4, values_l);
|
||||
int val2 = int_from_table_2(aux8+ 8, values_l);
|
||||
int val3 = int_from_table_2(aux8+12, values_l);
|
||||
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = val0;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = val1;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = val2;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = val3;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = val0;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = val1;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = val2;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = val3;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
|
||||
qh >>= 4;
|
||||
|
||||
Reference in New Issue
Block a user