Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-07 06:50:09 +00:00)
Bitnet (2.25 bpw): faster Metal dot product
With this change, token generation (TG-128) reaches 97 t/s.
This commit is contained in:
@@ -5127,16 +5127,16 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
||||
device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[8];
|
||||
float yl[16];
|
||||
float sumf[N_DST]={0.f};
|
||||
float d1bn[N_DST];
|
||||
|
||||
const int nb32 = nb * (QK_IQ1BN / 32);
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int ir = tiisg%8; // 0...7
|
||||
const int ix = tiisg/4; // 0...7
|
||||
const int ir = tiisg%4; // 0...3
|
||||
|
||||
device const float * y4 = y + 64 * ix + 2 * ir;
|
||||
device const float * y4 = y + 64 * ix + 4 * ir;
|
||||
|
||||
typedef union { float f; uint32_t i; } scale_t;
|
||||
scale_t scale;
|
||||
@@ -5145,32 +5145,34 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
||||
d1bn[row] = x[nb*row].d;
|
||||
}
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
for (int ib = ix; ib < nb; ib += 8) {
|
||||
|
||||
float sumy = 0.f;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
yl[i+0] = y4[i+ 0]; sumy += yl[i];
|
||||
yl[i+2] = y4[i+16]; sumy += yl[i+2];
|
||||
yl[i+4] = y4[i+32]; sumy += yl[i+4];
|
||||
yl[i+6] = y4[i+48]; sumy += yl[i+6];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
|
||||
yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4];
|
||||
yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
|
||||
yl[i+12] = y4[i+48]; sumy += yl[i+12];
|
||||
}
|
||||
|
||||
device const uint8_t * qs = x[ib].qs + 2*ir;
|
||||
device const uint8_t * qs = x[ib].qs + 4*ir;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
float acc = 0;
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
acc += yl[j+0] * ((qs[j] >> 0) & 0x03) + yl[j+2] * ((qs[j] >> 2) & 0x03)
|
||||
+ yl[j+4] * ((qs[j] >> 4) & 0x03) + yl[j+6] * ((qs[j] >> 6) & 0x03);
|
||||
float4 acc = {0.f};
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
acc[0] += yl[j+ 0] * (qs[j] & 0x03);
|
||||
acc[1] += yl[j+ 4] * (qs[j] & 0x0c);
|
||||
acc[2] += yl[j+ 8] * (qs[j] & 0x30);
|
||||
acc[3] += yl[j+12] * (qs[j] & 0xc0);
|
||||
}
|
||||
|
||||
sumf[row] += acc - sumy;
|
||||
sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy;
|
||||
|
||||
qs += nb*sizeof(block_iq2_bn);
|
||||
}
|
||||
|
||||
y4 += 64 * 4;
|
||||
y4 += 64 * 8;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row += 2) {
|
||||
|
||||
Reference in New Issue
Block a user