mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
CUDA: faster prompt processing for 4-bit quants (#713)
* Use __byte_perm in get_int_from_table_16 * Use get_int_from_table_16 everywhere for 4-bit quants --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -246,6 +246,25 @@ __device__ __forceinline__ void vec_dot_iq4_k_q8_1(
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
uint32_t v1, v2, v3, v4, mask;
|
||||
const uint32_t * values32 = (const uint32_t *)values;
|
||||
|
||||
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
||||
// Perform lookups in the lower half of the table (indices 0-7).
|
||||
v1 = __byte_perm(values32[0], values32[1], q4);
|
||||
// Perform lookups in the upper half of the table (indices 8-15).
|
||||
v2 = __byte_perm(values32[2], values32[3], q4);
|
||||
// Select between the low and high results based on the MSB of each index nibble.
|
||||
v3 = __byte_perm(v1, v2, mask);
|
||||
// Same for the upper part of q4.
|
||||
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
|
||||
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
|
||||
v4 = __byte_perm(v1, v2, mask >> 16);
|
||||
|
||||
// Mix the results to get the final int2.
|
||||
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
|
||||
#else
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
|
||||
@@ -255,6 +274,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
|
||||
@@ -389,19 +409,18 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
|
||||
// iqs is 0...28
|
||||
const int ib32 = iqs/4; // Why iqs/4 ?
|
||||
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
|
||||
const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
|
||||
int v1, v2;
|
||||
auto values = iq4k_values + ((bq4->scales[ib32] & 1) << 4);
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
|
||||
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
|
||||
auto v = get_int_from_table_16(q4[j], values);
|
||||
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
|
||||
}
|
||||
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
}
|
||||
@@ -560,7 +579,6 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
|
||||
// iqs is 0...28
|
||||
const int ib32 = iqs/4; // Why iqs/4 ?
|
||||
@@ -569,14 +587,14 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
|
||||
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
|
||||
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
|
||||
const float dl = scale * ((ls & 254) - 127);
|
||||
int v1, v2;
|
||||
auto values = iq4k_values + ((ls & 1) << 4);
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t aux32 = q4[j] & 0xfffefffe;
|
||||
aux32 ^= (aux32 >> 1);
|
||||
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
|
||||
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
|
||||
auto v = get_int_from_table_16(aux32, values);
|
||||
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
|
||||
}
|
||||
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
}
|
||||
|
||||
@@ -2509,9 +2509,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const int kbx = 0; // threadIdx.x / QI4_XS
|
||||
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
|
||||
|
||||
uint32_t aux32[2];
|
||||
auto a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
||||
int i = i0 + threadIdx.y;
|
||||
@@ -2523,15 +2520,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
|
||||
|
||||
const int q4 = get_int_b4(bxi->qs, kqsx);
|
||||
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
|
||||
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
|
||||
const int2 v = get_int_from_table_16(q4);
|
||||
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -2842,9 +2838,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const int kqsx = threadIdx.x / 4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
auto a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
@@ -2857,19 +2850,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
|
||||
const int ls = (bxi->scales[kqsx] & 254) - 127;
|
||||
|
||||
auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
|
||||
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
|
||||
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
|
||||
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
|
||||
const int2 v = get_int_from_table_16(q4, values);
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
@@ -2896,9 +2888,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const int kqsx = threadIdx.x/4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
@@ -2913,19 +2902,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0;
|
||||
|
||||
const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
|
||||
auto values = iq4k_table + ((bxi->scales[4*kqsx+ir] & 1) << 8);
|
||||
auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir);
|
||||
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
|
||||
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
|
||||
const int2 v = get_int_from_table_16(q4, values);
|
||||
const int k0 = 8*kqsx + 4*(j%2) + j/2;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
|
||||
@@ -14,9 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
|
||||
const int kqsx = threadIdx.x / 4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
auto a8 = (const uint8_t *)aux32;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
|
||||
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
|
||||
@@ -31,20 +28,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
|
||||
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
|
||||
|
||||
auto values = iq4k_table + ((ls & 1) << 8);
|
||||
auto values = iq4k_values + ((ls & 1) << 4);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t val = q4[j] & 0xfffefffe;
|
||||
val = val ^ (val >> 1);
|
||||
aux32[0] = (val >> 0) & 0x0f0f0f0f;
|
||||
aux32[1] = (val >> 4) & 0x0f0f0f0f;
|
||||
auto v = get_int_from_table_16(val, values);
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
|
||||
#else
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
|
||||
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
|
||||
@@ -1126,21 +1126,26 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
|
||||
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(
|
||||
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
|
||||
|
||||
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
|
||||
const int8_t * q1_8 = (const int8_t *) &q1_32;
|
||||
const char4 val1_8 = make_char4(
|
||||
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
uint32_t v1, v2, v3, v4, mask;
|
||||
const uint32_t * values32 = (const uint32_t *)values;
|
||||
|
||||
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
||||
// Perform lookups in the lower half of the table (indices 0-7).
|
||||
v1 = __byte_perm(values32[0], values32[1], q4);
|
||||
// Perform lookups in the upper half of the table (indices 8-15).
|
||||
v2 = __byte_perm(values32[2], values32[3], q4);
|
||||
// Select between the low and high results based on the MSB of each index nibble.
|
||||
v3 = __byte_perm(v1, v2, mask);
|
||||
// Same for the upper part of q4.
|
||||
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
|
||||
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
|
||||
v4 = __byte_perm(v1, v2, mask >> 16);
|
||||
|
||||
// Mix the results to get the final int2.
|
||||
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
|
||||
#else
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
|
||||
@@ -1150,6 +1155,11 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
|
||||
return get_int_from_table_16(q4, kvalues_iq4nl);
|
||||
}
|
||||
|
||||
#define VDR_IQ4_NL_Q8_1_MMVQ 2
|
||||
|
||||
Reference in New Issue
Block a user