mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
iq1_bn(Metal): 64 -> 66.2 t/s for TG
This commit is contained in:
@@ -5026,16 +5026,16 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
|
device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
|
|
||||||
float yl[16];
|
float yl[8];
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
float d1bn[N_DST];
|
float d1bn[N_DST];
|
||||||
|
|
||||||
const int nb32 = nb * (QK_IQ1BN / 32);
|
const int nb32 = nb * (QK_IQ1BN / 32);
|
||||||
|
|
||||||
const int ix = tiisg/2;
|
const int ix = tiisg/4;
|
||||||
const int ir = tiisg%2;
|
const int ir = tiisg%4;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix + 16 * ir;
|
device const float * y4 = y + 32 * ix + 8 * ir;
|
||||||
|
|
||||||
typedef union { float f; uint32_t i; } scale_t;
|
typedef union { float f; uint32_t i; } scale_t;
|
||||||
scale_t scale;
|
scale_t scale;
|
||||||
@@ -5046,15 +5046,14 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
d1bn[row] = scale.f;
|
d1bn[row] = scale.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t aux32;
|
uint16_t aux16;
|
||||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
|
thread const uint8_t * aux8 = (thread const uint8_t *)&aux16;
|
||||||
|
|
||||||
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
|
||||||
|
|
||||||
float2 sumy = {0.f};
|
float sumy = 0.f;
|
||||||
for (int i = 0; i < 8; ++i) {
|
for (int i = 0; i < 8; ++i) {
|
||||||
yl[i+0] = y4[i+0]; sumy[0] += yl[i+ 0];
|
yl[i] = y4[i]; sumy += yl[i];
|
||||||
yl[i+8] = y4[i+8]; sumy[1] += yl[i+ 8];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ibl = ib32 / (QK_IQ1BN / 32);
|
const int ibl = ib32 / (QK_IQ1BN / 32);
|
||||||
@@ -5062,41 +5061,32 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
|
|
||||||
device const block_iq1_bn * xr = x + ibl;
|
device const block_iq1_bn * xr = x + ibl;
|
||||||
device const uint16_t * extra = (device const uint16_t *)&xr->extra;
|
device const uint16_t * extra = (device const uint16_t *)&xr->extra;
|
||||||
device const uint8_t * ql = xr->ql + 4 * ib + 2*ir;
|
device const uint8_t * ql = xr->ql + 4 * ib + ir;
|
||||||
device const uint8_t * qh = xr->qh + 2 * ib + ir;;
|
device const uint8_t * qh = xr->qh + 2 * ib + ir/2;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
|
||||||
uint8_t signs = extra[0] >> (8 + 4*ib + 2*ir);
|
uint8_t signs = extra[0] >> (8 + 4*ib + ir);
|
||||||
float2 acc = {0.f};
|
float acc = 0.f;
|
||||||
|
|
||||||
uint32_t v1 = iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)];
|
uint16_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
|
||||||
uint32_t v2 = iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)];
|
aux16 = v & 0x0303;
|
||||||
uint32_t v = v1 | (v2 << 16);
|
acc += yl[0] * aux8[0] + yl[4] * aux8[1];
|
||||||
aux32 = v & 0x03030303;
|
aux16 = (v >> 2) & 0x0303;
|
||||||
acc[0] += yl[ 0] * aux8[0] + yl[ 4] * aux8[1];
|
acc += yl[1] * aux8[0] + yl[5] * aux8[1];
|
||||||
acc[1] += yl[ 8] * aux8[2] + yl[12] * aux8[3];
|
aux16 = (v >> 4) & 0x0303;
|
||||||
aux32 = (v >> 2) & 0x03030303;
|
acc += yl[2] * aux8[0] + yl[6] * aux8[1];
|
||||||
acc[0] += yl[ 1] * aux8[0] + yl[ 5] * aux8[1];
|
aux16 = (v >> 6) & 0x0303;
|
||||||
acc[1] += yl[ 9] * aux8[2] + yl[13] * aux8[3];
|
acc += yl[3] * aux8[0] + yl[7] * aux8[1];
|
||||||
aux32 = (v >> 4) & 0x03030303;
|
|
||||||
acc[0] += yl[ 2] * aux8[0] + yl[ 6] * aux8[1];
|
|
||||||
acc[1] += yl[10] * aux8[2] + yl[14] * aux8[3];
|
|
||||||
aux32 = (v >> 6) & 0x03030303;
|
|
||||||
acc[0] += yl[ 3] * aux8[0] + yl[ 7] * aux8[1];
|
|
||||||
acc[1] += yl[12] * aux8[2] + yl[15] * aux8[3];
|
|
||||||
|
|
||||||
acc -= sumy;
|
sumf[row] += (signs & 1 ? sumy-acc : acc-sumy);
|
||||||
float sum = (signs & 1 ? -acc[0] : acc[0]) + (signs & 2 ? -acc[1] : acc[1]);
|
|
||||||
|
|
||||||
sumf[row] += sum;
|
|
||||||
|
|
||||||
extra += nb*sizeof(block_iq1_bn)/2;
|
extra += nb*sizeof(block_iq1_bn)/2;
|
||||||
ql += nb*sizeof(block_iq1_bn);
|
ql += nb*sizeof(block_iq1_bn);
|
||||||
qh += nb*sizeof(block_iq1_bn);
|
qh += nb*sizeof(block_iq1_bn);
|
||||||
}
|
}
|
||||||
|
|
||||||
y4 += 32 * 16;
|
y4 += 32 * 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
|
|||||||
Reference in New Issue
Block a user