Use __byte_perm in get_int_from_table_16

This commit is contained in:
Iwan Kawrakow
2025-08-21 10:14:59 +03:00
parent 0cb6696943
commit 7fe9cd9968
3 changed files with 71 additions and 28 deletions

View File

@@ -246,6 +246,25 @@ __device__ __forceinline__ void vec_dot_iq4_k_q8_1(
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
uint32_t v1, v2, v3, v4, mask;
const uint32_t * values32 = (const uint32_t *)values;
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
// Perform lookups in the lower half of the table (indices 0-7).
v1 = __byte_perm(values32[0], values32[1], q4);
// Perform lookups in the upper half of the table (indices 8-15).
v2 = __byte_perm(values32[2], values32[3], q4);
// Select between the low and high results based on the MSB of each index nibble.
v3 = __byte_perm(v1, v2, mask);
// Same for the upper part of q4.
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
v4 = __byte_perm(v1, v2, mask >> 16);
// Mix the results to get the final int2.
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
#else
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
@@ -255,6 +274,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
@@ -389,19 +409,30 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
float scale = *(const float *)vbq;
const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx;
const uint8_t * all_values = (const uint8_t *)iq4k_values;
//const uint8_t * all_values = (const uint8_t *)iq4k_values;
// iqs is 0...28
const int ib32 = iqs/4; // Why iqs/4 ?
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
int v1, v2;
auto values = iq4k_values + ((bq4->scales[ib32] & 1) << 4);
//auto values = iq4k_table + ((bq4->scales[ib32] & 1) << 8);
//uint32_t aux32[2];
//auto a8 = (const uint8_t *)aux32;
//int v1, v2;
int sumi = 0;
for (int j = 0; j < 4; ++j) {
get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
//aux32[0] = (q4[j] >> 0) & 0x0f0f0f0f;
//aux32[1] = (q4[j] >> 4) & 0x0f0f0f0f;
//sumi = ggml_cuda_dp4a(int_from_table_x(a8+0, values), q8[j+0], sumi);
//sumi = ggml_cuda_dp4a(int_from_table_x(a8+4, values), q8[j+4], sumi);
auto v = get_int_from_table_16(q4[j], values);
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
////get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
////sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
////sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
}
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}

View File

@@ -2842,8 +2842,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kqsx = threadIdx.x / 4;
uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;
//uint32_t aux32[2];
//auto a8 = (const uint8_t *)aux32;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
@@ -2857,19 +2857,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
const int ls = (bxi->scales[kqsx] & 254) - 127;
auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
//auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int2 v = get_int_from_table_16(q4, values);
//aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
//aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; //int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; //int_from_table_x(a8+4, values);
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; //int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; //int_from_table_x(a8+4, values);
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE

View File

@@ -1126,21 +1126,26 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
uint32_t v1, v2, v3, v4, mask;
const uint32_t * values32 = (const uint32_t *)values;
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
// Perform lookups in the lower half of the table (indices 0-7).
v1 = __byte_perm(values32[0], values32[1], q4);
// Perform lookups in the upper half of the table (indices 8-15).
v2 = __byte_perm(values32[2], values32[3], q4);
// Select between the low and high results based on the MSB of each index nibble.
v3 = __byte_perm(v1, v2, mask);
// Same for the upper part of q4.
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
v4 = __byte_perm(v1, v2, mask >> 16);
// Mix the results to get the final int2.
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
#else
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
@@ -1150,6 +1155,11 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
return get_int_from_table_16(q4, kvalues_iq4nl);
}
#define VDR_IQ4_NL_Q8_1_MMVQ 2