mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
iq4_kt: failed attemt to adjust CUDA dot product
It was working for 4.125 bpw. But after changing to 4.0 bpw there is something wrong and I don't see the bug.
This commit is contained in:
@@ -176,81 +176,98 @@ static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ v
|
||||
constexpr uint32_t kb = 64248484;
|
||||
constexpr uint32_t kmask = 0x8fff8fff;
|
||||
constexpr uint32_t km32 = 0x3b603b60;
|
||||
constexpr int kNumGroups = 64;
|
||||
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const float * dptr = (const float *)((const char *)vx + row*row_size);
|
||||
const float d = *dptr * 31.75f * 1.01f;
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 1);
|
||||
const float d = dptr[0] * 31.75f * 1.01f;
|
||||
const float row_av = dptr[1];
|
||||
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
|
||||
dfloat2 tmp = {};
|
||||
dfloat2 tmp1 = {};
|
||||
dfloat2 tmp2 = {};
|
||||
|
||||
const int it = threadIdx.x/2;
|
||||
const int ix = threadIdx.x%2;
|
||||
const int it = threadIdx.x/2; // 0...15
|
||||
const int ix = threadIdx.x%2; // 0 or 1
|
||||
const int ib32 = it/4; // 0...3
|
||||
const int ig = it%4; // 0...3
|
||||
const int jj = ib32*8 + 2*ig; // 0...30 in steps of 2
|
||||
|
||||
uint32_t s[4];
|
||||
const half * h = (const half *)s;
|
||||
|
||||
for (int i = ix; i < num_blocks_per_row; i += 2) {
|
||||
// const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
|
||||
// const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
// const dfloat scale1 = x[i].scales[it/8];
|
||||
// const dfloat scale2 = x[i].scales[it/8 + 2];
|
||||
// const dfloat2 dl1 = {scale1, scale1};
|
||||
// const dfloat2 dl2 = {scale2, scale2};
|
||||
// dfloat2 bdot1 = {0, 0};
|
||||
// dfloat2 bdot2 = {0, 0};
|
||||
// uint32_t val1 = ql[2*it+ 0] + 4096;
|
||||
// uint32_t val2 = ql[2*it+32] + 4096;
|
||||
// for (int k = 0; k < 2; ++k) {
|
||||
// val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
|
||||
// val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
|
||||
// val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
|
||||
// val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
|
||||
//#ifdef GGML_CUDA_F16
|
||||
// bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
|
||||
// bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
|
||||
//#else
|
||||
// bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
|
||||
// bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
|
||||
// bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
|
||||
// bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
|
||||
//#endif
|
||||
// }
|
||||
// val1 = ql[2*it+ 1] + 4096;
|
||||
// val2 = ql[2*it+33] + 4096;
|
||||
// for (int k = 2; k < 4; ++k) {
|
||||
// val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
|
||||
// val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
|
||||
// val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
|
||||
// val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
|
||||
//#ifdef GGML_CUDA_F16
|
||||
// bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
|
||||
// bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
|
||||
//#else
|
||||
// bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
|
||||
// bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
|
||||
// bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
|
||||
// bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
|
||||
//#endif
|
||||
// }
|
||||
//#ifdef GGML_CUDA_F16
|
||||
// tmp = __hfma2(dl1, bdot1, tmp);
|
||||
// tmp = __hfma2(dl2, bdot2, tmp);
|
||||
//#else
|
||||
// tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
|
||||
// tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
|
||||
//#endif
|
||||
const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
|
||||
const uint32_t * shb = x[i].qs;
|
||||
const uint8_t * ql = (const uint8_t *)(shb + 8);
|
||||
const uint8_t * qh = ql + kNumGroups;
|
||||
const uint32_t offset1 = 4096 + ((shb[ib32+0] & 1) << 15);
|
||||
const uint32_t offset2 = 4096 + ((shb[ib32+4] & 1) << 15);
|
||||
const dfloat scale1 = ((shb[ib32+0] & 0xff) >> 1) - 64;
|
||||
const dfloat scale2 = ((shb[ib32+4] & 0xff) >> 1) - 64;
|
||||
const dfloat2 dl1 = {scale1, scale1};
|
||||
const dfloat2 dl2 = {scale2, scale2};
|
||||
dfloat2 bdot1 = {0, 0};
|
||||
dfloat2 bdot2 = {0, 0};
|
||||
uint32_t val1 = ql[jj+ 0] + ((qh[jj] << 8) & 0xf00) + (((shb[ib32+0] >> (8 + 6*ig+0)) & 7) << 12) + offset1;
|
||||
uint32_t val2 = ql[jj+32] + ((qh[jj] << 4) & 0xf00) + (((shb[ib32+4] >> (8 + 6*ig+0)) & 7) << 12) + offset2;
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
|
||||
val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
|
||||
val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
|
||||
val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
|
||||
#ifdef GGML_CUDA_F16
|
||||
bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
|
||||
bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
|
||||
tmp2 += y[k] + y[k+64];
|
||||
#else
|
||||
bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
|
||||
bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
|
||||
bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
|
||||
bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
|
||||
tmp2.x += y[k].x + y[k+64].x;
|
||||
tmp2.y += y[k].y + y[k+64].y;
|
||||
#endif
|
||||
}
|
||||
val1 = ql[jj+ 1] + ((qh[jj+1] << 8) & 0xf00) + (((shb[ib32+0] >> (8 + 6*ig+3)) & 7) << 12) + offset1;
|
||||
val2 = ql[jj+33] + ((qh[jj+1] << 4) & 0xf00) + (((shb[ib32+4] >> (8 + 6*ig+3)) & 7) << 12) + offset2;
|
||||
for (int k = 2; k < 4; ++k) {
|
||||
val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32;
|
||||
val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32;
|
||||
val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32;
|
||||
val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32;
|
||||
#ifdef GGML_CUDA_F16
|
||||
bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1);
|
||||
bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2);
|
||||
tmp2 += y[k] + y[k+64];
|
||||
#else
|
||||
bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]);
|
||||
bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]);
|
||||
bdot2.x += y[k+64].x * (float)(h[4] + h[5]);
|
||||
bdot2.y += y[k+64].y * (float)(h[6] + h[7]);
|
||||
tmp2.x += y[k].x + y[k+64].x;
|
||||
tmp2.y += y[k].y + y[k+64].y;
|
||||
#endif
|
||||
}
|
||||
#ifdef GGML_CUDA_F16
|
||||
tmp1 = __hfma2(dl1, bdot1, tmp1);
|
||||
tmp1 = __hfma2(dl2, bdot2, tmp1);
|
||||
#else
|
||||
tmp1.x += dl1.x * bdot1.x + dl2.x * bdot2.x;
|
||||
tmp1.y += dl1.y * bdot1.y + dl2.y * bdot2.y;
|
||||
#endif
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
float tmp = d * (float)(tmp1.x + tmp1.y) + row_av * (float)(tmp2.x + tmp2.y);
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[row] = d * (float)(tmp.x + tmp.y);
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3612,12 +3612,26 @@ std::vector<float> QuantizerIQKT<block_size, group_size, num_bits>::cluster_poin
|
||||
for (int ip = 0; ip < npoint; ++ip) {
|
||||
auto vp = points.data() + ndim*ip;
|
||||
uint16_t u = 0;
|
||||
if (ncluster == 255) {
|
||||
if (ncluster == 256) {
|
||||
for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k);
|
||||
} else {
|
||||
int s = 1;
|
||||
for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; }
|
||||
}
|
||||
if (u >= int(counts.size())) {
|
||||
printf("Oops: u = %u, vp = %g, %g, %g, %g\n", u, vp[0], vp[1], vp[2], vp[3]);
|
||||
u = 0;
|
||||
if (ncluster == 256) {
|
||||
for (int k = 0; k < ndim; ++k) {
|
||||
auto bin = bin4(vp[k]); u |= (bin << 2*k);
|
||||
printf(" bin[%d] = %d, u = %u", k, bin, u);
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < ndim; ++k) printf(" bin[%d] = %d", k, bin5(vp[k]));
|
||||
}
|
||||
printf("\n");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
++counts[u];
|
||||
for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user