mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
iq5_ks: dot product works on CUDA
This commit is contained in:
@@ -331,10 +331,8 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
|
||||
__device__ __forceinline__ float vec_dot_iq5_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
||||
// TODO
|
||||
return 0.f;
|
||||
|
||||
const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx;
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq5_ks * bq5 = (const block_iq5_ks *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq5nl_values;
|
||||
|
||||
int i4 = iqs/4; // 0...7. Blocks of 16 index is 4*(i4/2) + (i4%2) + (0 and 2)
|
||||
@@ -343,9 +341,8 @@ __device__ __forceinline__ float vec_dot_iq5_ks_q8_1(
|
||||
const int32_t * q8_2 = (const int *)bq8_1[2*(i4/2)+1].qs + 4*(i4%2);
|
||||
const uint32_t * q4 = (const uint32_t *)bq5->qs + 8*(i4/2) + 4*(i4%2);
|
||||
const uint32_t * qh = (const uint32_t *)bq5->qh + 4*(i4%2);
|
||||
const uint16_t extra = bq5->extra >> (4*(i4/2) + (i4%2));
|
||||
const uint8_t * values1 = all_values + 32*(extra & 1);
|
||||
const uint8_t * values2 = all_values + 8*(extra & 4);
|
||||
const uint8_t * values1 = all_values + ((bq5->scales[2*(i4/2)+0] & 1) << 5);
|
||||
const uint8_t * values2 = all_values + ((bq5->scales[2*(i4/2)+1] & 1) << 5);
|
||||
uint32_t aux32[2];
|
||||
const uint8_t * a8 = (const uint8_t *)aux32;
|
||||
int v1, v2;
|
||||
@@ -359,11 +356,9 @@ __device__ __forceinline__ float vec_dot_iq5_ks_q8_1(
|
||||
sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1);
|
||||
sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2);
|
||||
}
|
||||
const float d5 = __half2float(bq5->d);
|
||||
const uint8_t sh = bq5->scales_h[i4/2] >> 2*(i4%2);
|
||||
const int ls1 = (((bq5->scales_l[2*(i4/2)+0] >> 4*(i4%2)) & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = (((bq5->scales_l[2*(i4/2)+1] >> 4*(i4%2)) & 0xf) | ((sh << 0) & 0x30)) - 32;
|
||||
return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
const int ls1 = (bq5->scales[2*(i4/2)+0] & 254) - 127;
|
||||
const int ls2 = (bq5->scales[2*(i4/2)+1] & 254) - 127;
|
||||
return scale * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
}
|
||||
|
||||
#define VDR_IQ6_K_Q8_1_MMVQ 4
|
||||
|
||||
Reference in New Issue
Block a user