Absoolutely don't see what is wrong with the iq1_bn and iq2_bn
vector dot product kernels.
This commit is contained in:
Iwan Kawrakow
2024-10-24 13:49:20 +02:00
parent 3ba962a68d
commit 6952e676e2

View File

@@ -5585,14 +5585,14 @@ void kernel_mul_mv_iq1_bn_f32_impl(
constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
const int ib = ix % (QK_IQ1BN / 32);
const int i16 = 2*ib + ir;
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
for (int j = 0; j < 16; ++j) yl[j] = y4[j];
const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32);
const int i16 = 2*ib + ir;
device const block_iq1_bn * xr = x + ibl;
device const uint8_t * ql = xr->ql + 3*i16;
device const uint8_t * extra = (device const uint8_t *)&xr->extra;
@@ -5602,8 +5602,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
float acc = 0;
int i = 0;
for (int k = 0; k < 3; ++k) {
//constant int8_t * vs = iq1bn_values + 5*ql[k];
//for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
uint8_t q = ql[k];
for (int j = 0; j < 5; ++j) {
uint8_t v = k_mult[j]*q;
@@ -5611,8 +5609,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
acc += yl[i++] * values[v];
}
}
//constant int8_t * vs = iq1bn_values + 5*extra[0];
//acc += yl[15] * vs[i16];
uint8_t v = k_mult[i16]*extra[0];
v = 3*v >> 8; //(v + (v >> 1)) >> 7;
acc += yl[15] * values[v];
@@ -5635,7 +5631,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
}
}
// TODO: unify with kernel_mul_mv_iq1_bn_f32_impl
void kernel_mul_mv_iq1_tn_f32_impl(
device const void * src0,
device const float * src1,
@@ -5661,7 +5656,7 @@ void kernel_mul_mv_iq1_tn_f32_impl(
// Why are we not passing in src0->nb[0]?
// But because we are not, we need to use this hack
const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K);
const uint row_size = 2 + sizeof(block_iq1_bn)*(ne00/QK_IQ1BN);
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
@@ -5769,9 +5764,9 @@ void kernel_mul_mv_iq2_bn_f32_impl(
device const uint8_t * cx = (device const uint8_t *) src0 + first_row*row_size + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16];
float sumf[N_DST]={0.f};
float scale[N_DST];
float4 yl[4];
float sumf[N_DST]={0.f};
float scale[N_DST];
for (int row = 0; row < N_DST; ++row) {
scale[row] = *((device const float *)(cx + row*row_size));
@@ -5780,29 +5775,26 @@ void kernel_mul_mv_iq2_bn_f32_impl(
const int ix = tiisg/4; // 0...7
const int ir = tiisg%4; // 0...3
device const float * y4 = y + QK_IQ1BN * ix + 4 * ir;
device const float4 * y4 = (device const float4 *)(y + QK_IQ1BN * ix + 4 * ir);
device const uint8_t * qs0 = cx + sizeof(float) + (QK_IQ1BN/4)*ix + 4*ir;
for (int ib = ix; ib < nb; ib += 8) {
float sumy = 0.f;
for (int i = 0; i < 4; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4];
yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
yl[i+12] = y4[i+48]; sumy += yl[i+12];
}
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[8]; yl[3] = y4[12];
float4 tmp = yl[0] + yl[1] + yl[2] + yl[3];
const float sumy = tmp[0] + tmp[1] + tmp[2] + tmp[3];
device const uint8_t * qs = qs0;
for (int row = 0; row < N_DST; row++) {
float4 acc = {0.f};
for (int j = 0; j < 4; ++j) {
acc[0] += yl[j+ 0] * (qs[j] & 0x03);
acc[1] += yl[j+ 4] * (qs[j] & 0x0c);
acc[2] += yl[j+ 8] * (qs[j] & 0x30);
acc[3] += yl[j+12] * (qs[j] & 0xc0);
acc[0] += yl[0][j] * (qs[j] & 0x03);
acc[1] += yl[1][j] * (qs[j] & 0x0c);
acc[2] += yl[2][j] * (qs[j] & 0x30);
acc[3] += yl[3][j] * (qs[j] & 0xc0);
}
sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy;
@@ -5810,7 +5802,7 @@ void kernel_mul_mv_iq2_bn_f32_impl(
qs += row_size;
}
y4 += QK_IQ1BN * 8;
y4 += QK_IQ1BN * 2;
qs0 += QK_IQ1BN * 2;
}