This commit is contained in:
Iwan Kawrakow
2025-05-29 16:28:17 +03:00
parent cc395cf879
commit 7a783af1ad

View File

@@ -632,9 +632,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
const float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
for (int iy = 0; iy < nrc_y * 2; ++iy) {
accd[iy] = vdupq_n_f32(0.0f);
}
for (int iy = 0; iy < nrc_y * 2; ++iy) accd[iy] = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; ++i) {
const uint32_t * shb = x[i].qs;
@@ -652,17 +650,12 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
uint32_t sh1 = shb[ib+0] >> 8;
uint32_t sh2 = shb[ib+4] >> 8;
for (int jj = 0; jj < 4; ++jj) {
//int j = 32*ib + 8*jj;
// -> (j/8)%4 = (4*ib+jj)%4 = jj%4;
// j/4 = 8*ib + 2*jj;
//const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
//const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
for (int j = 0; j < 4; ++j) {
uint32_t val1 = ql[8*ib+2*jj+ 0] + ((qh[8*ib+2*jj+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
uint32_t val2 = ql[8*ib+2*jj+32] + ((qh[8*ib+2*jj+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
uint32_t val3 = ql[8*ib+2*jj+ 1] + ((qh[8*ib+2*jj+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
uint32_t val4 = ql[8*ib+2*jj+33] + ((qh[8*ib+2*jj+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
sh1 >>= 6;
sh2 >>= 6;
@@ -675,8 +668,8 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
x2.val[1] = vmulq_f32(scale2, x2.val[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj);
auto y2 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*jj + 128);
auto y1 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j);
auto y2 = vld1q_f32_x2(y[iy] + i*QK_K + 32*ib + 8*j + 128);
accd[iy*2 + 0] = vfmaq_f32(accd[iy*2 + 0], y1.val[0], x1.val[0]);
accd[iy*2 + 1] = vfmaq_f32(accd[iy*2 + 1], y1.val[1], x1.val[1]);
@@ -689,10 +682,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
// Sum the two accumulators for this y row
float32x4_t sum1 = vaddq_f32(accd[iy*2], accd[iy*2 + 1]);
// Compute final result
float result = d*vaddvq_f32(sum1) + row_av*row_sum[iy];
info.store(ix, iy, result);
}