mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
WIP
Absolutely don't see what is wrong with the iq1_bn and iq2_bn vector dot product kernels.
This commit is contained in:
@@ -5585,14 +5585,14 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
||||
|
||||
constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
|
||||
|
||||
const int ib = ix % (QK_IQ1BN / 32);
|
||||
const int i16 = 2*ib + ir;
|
||||
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
||||
|
||||
for (int j = 0; j < 16; ++j) yl[j] = y4[j];
|
||||
|
||||
const int ibl = ib32 / (QK_IQ1BN / 32);
|
||||
const int ib = ib32 % (QK_IQ1BN / 32);
|
||||
const int i16 = 2*ib + ir;
|
||||
|
||||
device const block_iq1_bn * xr = x + ibl;
|
||||
device const uint8_t * ql = xr->ql + 3*i16;
|
||||
device const uint8_t * extra = (device const uint8_t *)&xr->extra;
|
||||
@@ -5602,8 +5602,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
||||
float acc = 0;
|
||||
int i = 0;
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
//constant int8_t * vs = iq1bn_values + 5*ql[k];
|
||||
//for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
|
||||
uint8_t q = ql[k];
|
||||
for (int j = 0; j < 5; ++j) {
|
||||
uint8_t v = k_mult[j]*q;
|
||||
@@ -5611,8 +5609,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
||||
acc += yl[i++] * values[v];
|
||||
}
|
||||
}
|
||||
//constant int8_t * vs = iq1bn_values + 5*extra[0];
|
||||
//acc += yl[15] * vs[i16];
|
||||
uint8_t v = k_mult[i16]*extra[0];
|
||||
v = 3*v >> 8; //(v + (v >> 1)) >> 7;
|
||||
acc += yl[15] * values[v];
|
||||
@@ -5635,7 +5631,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: unify with kernel_mul_mv_iq1_bn_f32_impl
|
||||
void kernel_mul_mv_iq1_tn_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -5661,7 +5656,7 @@ void kernel_mul_mv_iq1_tn_f32_impl(
|
||||
|
||||
// Why are we not passing in src0->nb[0]?
|
||||
// But because we are not, we need to use this hack
|
||||
const uint row_size = 2+sizeof(block_iq1_tn)*(ne00/QK_K);
|
||||
const uint row_size = 2 + sizeof(block_iq1_bn)*(ne00/QK_IQ1BN);
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
|
||||
@@ -5769,9 +5764,9 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
||||
device const uint8_t * cx = (device const uint8_t *) src0 + first_row*row_size + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[16];
|
||||
float sumf[N_DST]={0.f};
|
||||
float scale[N_DST];
|
||||
float4 yl[4];
|
||||
float sumf[N_DST]={0.f};
|
||||
float scale[N_DST];
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
scale[row] = *((device const float *)(cx + row*row_size));
|
||||
@@ -5780,29 +5775,26 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
||||
const int ix = tiisg/4; // 0...7
|
||||
const int ir = tiisg%4; // 0...3
|
||||
|
||||
device const float * y4 = y + QK_IQ1BN * ix + 4 * ir;
|
||||
device const float4 * y4 = (device const float4 *)(y + QK_IQ1BN * ix + 4 * ir);
|
||||
device const uint8_t * qs0 = cx + sizeof(float) + (QK_IQ1BN/4)*ix + 4*ir;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 8) {
|
||||
|
||||
float sumy = 0.f;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
|
||||
yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4];
|
||||
yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
|
||||
yl[i+12] = y4[i+48]; sumy += yl[i+12];
|
||||
}
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[8]; yl[3] = y4[12];
|
||||
float4 tmp = yl[0] + yl[1] + yl[2] + yl[3];
|
||||
const float sumy = tmp[0] + tmp[1] + tmp[2] + tmp[3];
|
||||
|
||||
device const uint8_t * qs = qs0;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
float4 acc = {0.f};
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
acc[0] += yl[j+ 0] * (qs[j] & 0x03);
|
||||
acc[1] += yl[j+ 4] * (qs[j] & 0x0c);
|
||||
acc[2] += yl[j+ 8] * (qs[j] & 0x30);
|
||||
acc[3] += yl[j+12] * (qs[j] & 0xc0);
|
||||
acc[0] += yl[0][j] * (qs[j] & 0x03);
|
||||
acc[1] += yl[1][j] * (qs[j] & 0x0c);
|
||||
acc[2] += yl[2][j] * (qs[j] & 0x30);
|
||||
acc[3] += yl[3][j] * (qs[j] & 0xc0);
|
||||
}
|
||||
|
||||
sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy;
|
||||
@@ -5810,7 +5802,7 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
||||
qs += row_size;
|
||||
}
|
||||
|
||||
y4 += QK_IQ1BN * 8;
|
||||
y4 += QK_IQ1BN * 2;
|
||||
qs0 += QK_IQ1BN * 2;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user