mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-06 22:40:09 +00:00
iq5_k: CUDA dot product still not working
This commit is contained in:
@@ -139,6 +139,9 @@ typedef sycl::half2 ggml_half2;
|
||||
#define QI4_XS (QK_K / (4*QR4_XS))
|
||||
#define QR4_XS 2
|
||||
|
||||
#define QI5_XS (QK_K / (4*QR5_XS))
|
||||
#define QR5_XS 2
|
||||
|
||||
#define QI3_S (QK_K / (4*QR3_S))
|
||||
#define QR3_S 4
|
||||
|
||||
|
||||
@@ -686,8 +686,8 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
static constexpr int qr = QR4_XS;
|
||||
static constexpr int qi = QI4_XS;
|
||||
static constexpr int qr = QR5_XS;
|
||||
static constexpr int qi = QI5_XS;
|
||||
};
|
||||
|
||||
template<>
|
||||
|
||||
@@ -1277,34 +1277,90 @@ static __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
|
||||
#define VDR_IQ5_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ5_K_Q8_1_MMQ 4
|
||||
|
||||
// TODO
|
||||
static __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
return 0;
|
||||
|
||||
// const block_iq5_k * bq4 = (const block_iq5_k *) vbq + kbx;
|
||||
// const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq5nl_values;
|
||||
|
||||
// iqs is 0...28
|
||||
const int il = iqs/2; // 0...14
|
||||
const int is = iqs%2; // 0 or 1
|
||||
const int ib32 = 2*(il/4); // 0, 2, 4, 6
|
||||
const int32_t * q8_1 = (const int *)bq8_1[ib32+0].qs + 4*is;
|
||||
const int32_t * q8_2 = (const int *)bq8_1[ib32+1].qs + 4*is;
|
||||
const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(ib32/2) + 4*is;
|
||||
const uint32_t * qh = (const uint32_t *)bq5->qh + 4*is;
|
||||
const uint16_t extra = bq5->extra >> (2*ib32 + is);
|
||||
const uint8_t * values1 = all_values + 32*(extra & 1);
|
||||
const uint8_t * values2 = all_values + 8*(extra & 4);
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * a8 = (const uint8_t *)aux32;
|
||||
int v1, v2;
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
uint32_t h = qh[j] >> ib32;
|
||||
aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x10101010);
|
||||
aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 3) & 0x10101010);
|
||||
v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24);
|
||||
v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24);
|
||||
sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1);
|
||||
sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2);
|
||||
}
|
||||
// Blocks of 16: 2*ib32 + is, 2*ib32 + is + 2
|
||||
const float d5 = __half2float(bq5->d);
|
||||
const uint8_t sh = bq5->scales_h[ib32/2] >> 2*(is%2);
|
||||
const int ls1 = (((bq5->scales_l[ib32+0] >> 4*is) & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = (((bq5->scales_l[ib32+1] >> 4*is) & 0xf) | ((sh << 0) & 0x30)) - 32;
|
||||
return d5 * (__low2float(bq8_1[ib32+0].ds) * sumi1 * ls1 + __low2float(bq8_1[ib32+1].ds) * sumi2 * ls2);
|
||||
}
|
||||
|
||||
//#define VDR_IQ5_K_Q8_1_MMVQ 2
|
||||
//#define VDR_IQ5_K_Q8_1_MMQ 8
|
||||
//
|
||||
//static __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
|
||||
// const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
//
|
||||
// const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx;
|
||||
//
|
||||
// // iqs is 0...28
|
||||
// const int ib32 = iqs/4;
|
||||
// // Why iqs/4 ?
|
||||
// const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
|
||||
// const uint16_t extra = bq4->extra >> 2*ib32;
|
||||
// // iqs = 0...7 -> bq8_offset = 0, iqs = 8...15 -> bq8_offset = 2, iqs = 16...23 -> bq8_offset = 4, iqs = 24...31 -> bq8_offset = 6
|
||||
// // bq8_offset = 0 -> 0...3, bq8_offset = 2 -> 8...11, bq8_offset = 4 -> 16...19, bq8_offset = 6 -> 24...27
|
||||
// const int bq8_offset = 2*((iqs/2)/4);
|
||||
// const int32_t * q8_1 = (const int *)bq8_1[bq8_offset+0].qs;
|
||||
// const int32_t * q8_2 = (const int *)bq8_1[bq8_offset+1].qs;
|
||||
// const uint32_t * q4 = (const uint32_t *)bq5->qs + 4*bq8_offset + ((iqs/2)%4);
|
||||
// const uint32_t * qh = (const uint32_t *)bq5->qh + ((iqs/2)%4);
|
||||
// const uint16_t extra = bq5->extra >> 2*bq8_offset;
|
||||
// const float d5 = __half2float(bq5->d);
|
||||
//
|
||||
// const uint8_t * values1;
|
||||
// const uint8_t * values2;
|
||||
// uint32_t indx[2];
|
||||
// const uint8_t * a8 = (const uint8_t *)indx;
|
||||
// int v1, v2;
|
||||
// int sumi1 = 0, sumi2 = 0;
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
|
||||
// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
|
||||
// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
|
||||
// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
|
||||
// }
|
||||
// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
|
||||
// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
|
||||
// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
// return d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
//
|
||||
// indx[0] = ((q4[0] >> 0) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+0)) << 4) & 0x10101010);
|
||||
// indx[1] = ((q4[0] >> 4) & 0x0f0f0f0f) | (((qh[0] >> (bq8_offset+1)) << 4) & 0x10101010);
|
||||
// values1 = (const uint8_t *)iq5nl_values + 32*(extra & 1);
|
||||
// values2 = (const uint8_t *)iq5nl_values + 8*(extra & 4);
|
||||
// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24);
|
||||
// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24);
|
||||
// int s1 = ggml_cuda_dp4a(v1, q8_1[0], 0) * (((bq5->scales_l[bq8_offset+0] & 0xf) | ((bq5->scales_h[bq8_offset/2] << 4) & 0x30)) - 32);
|
||||
// int s2 = ggml_cuda_dp4a(v2, q8_2[0], 0) * (((bq5->scales_l[bq8_offset+1] & 0xf) | ((bq5->scales_h[bq8_offset/2] >> 0) & 0x30)) - 32);
|
||||
//
|
||||
// indx[0] = ((q4[4] >> 0) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+0)) << 4) & 0x10101010);
|
||||
// indx[1] = ((q4[4] >> 4) & 0x0f0f0f0f) | (((qh[4] >> (bq8_offset+1)) << 4) & 0x10101010);
|
||||
// values1 = (const uint8_t *)iq5nl_values + 16*(extra & 2);
|
||||
// values2 = (const uint8_t *)iq5nl_values + 4*(extra & 8);
|
||||
// v1 = values1[a8[0]] | (values1[a8[1]] << 8) | (values1[a8[2]] << 16) | (values1[a8[3]] << 24);
|
||||
// v2 = values2[a8[4]] | (values2[a8[5]] << 8) | (values2[a8[6]] << 16) | (values2[a8[7]] << 24);
|
||||
// int s3 = ggml_cuda_dp4a(v1, q8_1[4], 0) * (((bq5->scales_l[bq8_offset+0] >> 4) | ((bq5->scales_h[bq8_offset/2] << 2) & 0x30)) - 32);
|
||||
// int s4 = ggml_cuda_dp4a(v2, q8_2[4], 0) * (((bq5->scales_l[bq8_offset+1] >> 4) | ((bq5->scales_h[bq8_offset/2] >> 2) & 0x30)) - 32);
|
||||
//
|
||||
// return d5*(__low2float(bq8_1[bq8_offset+0].ds) * (s1 + s3) + __low2float(bq8_1[bq8_offset+1].ds) * (s2 + s4));
|
||||
//
|
||||
//}
|
||||
|
||||
#define VDR_IQ2_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ2_K_Q8_1_MMQ 4
|
||||
|
||||
Reference in New Issue
Block a user