mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
iq3_k: Metal dot product
Quite slow: 43 t/s for a 7B model
This commit is contained in:
@@ -3069,6 +3069,8 @@ constexpr constant static float kvalues_iq5k_f[64] = {
|
||||
|
||||
// Lookup table of float quant values for IQ2_K, laid out as two 4-entry halves:
//   [0..3] = { -31, -13,  1, 17 }
//   [4..7] = { -26,  -8,  6, 22 }   (each entry is the first half + 5)
// Mat-vec code selects a half with `kvalues_iq2k_f + 4*(extra & 1)` and then
// indexes it with a 2-bit quant (values masked with 0x03).
constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f };
|
||||
|
||||
// Lookup table of float quant values for IQ3_K, laid out as two 8-entry halves:
//   [0..7]  = { -63, -40, -23, -10, 1, 13, 28, 47 }
//   [8..15] = { -59, -36, -19,  -6, 5, 17, 32, 51 }   (each entry is the first half + 4)
// kernel_mul_mv_iq3_k_f32_impl selects a half with `kvalues_iq3k_f + 8*(extra & 1)`
// and indexes it with a 3-bit value: the low 2 bits come from `qs` (mask 0x03)
// and the third bit from `qh` (mask 0x04).
constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f };
|
||||
|
||||
kernel void kernel_cpy_f32_iq4_nl(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
@@ -5314,7 +5316,6 @@ kernel void kernel_mul_mv_iq2_k_f32(
|
||||
kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
// TODO
|
||||
void kernel_mul_mv_iq3_k_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -5346,14 +5347,12 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
device const block_iq2_k * x = (device const block_iq2_k *) src0 + ib_row + offset0;
|
||||
device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int step = (sizeof(block_q2_K) * nb) / 4;
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
@@ -5362,18 +5361,12 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
|
||||
uint32_t vl[2], vh[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
//float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
||||
// yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
||||
// yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
||||
// yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
||||
//}
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0];
|
||||
yl[i+ 8] = y4[i+32];
|
||||
@@ -5383,28 +5376,34 @@ void kernel_mul_mv_iq3_k_f32_impl(
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
device const block_iq2_k & xb = x[row*nb + ib];
|
||||
device const uint32_t * q32 = (device const uint32_t *)xb.qs + 8*iq + 2*ir;
|
||||
device const uint32_t * sc = (device const uint32_t *)xb.scales;
|
||||
device const block_iq3_k & xb = x[row*nb + ib];
|
||||
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
|
||||
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
|
||||
device const uint32_t * sc = (device const uint32_t *)xb.scales_l;
|
||||
|
||||
const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&scales32;
|
||||
uint16_t extra = xb.extra >> (8*iq + is);
|
||||
uint16_t signs = xb.scales_h >> (8*iq + is);
|
||||
|
||||
vl[0] = ql16[0] | ql16[1] << 16;
|
||||
vl[1] = ql16[2] | ql16[3] << 16;
|
||||
vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
|
||||
vh[1] = ((qh16[2] | (qh16[3] << 16)) << 4*(1-iq)) >> 2;
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
constant float * values = kvalues_iq2k_f + 4*(extra & 1);
|
||||
constant float * values = kvalues_iq3k_f + 8*(extra & 1);
|
||||
extra >>= 2;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
aux32 = (q32[i] >> 2*l) & 0x03030303;
|
||||
acc[l] += values[aux8[0]] * yl[8*l + 4*i + 0] +
|
||||
+ values[aux8[1]] * yl[8*l + 4*i + 1] +
|
||||
+ values[aux8[2]] * yl[8*l + 4*i + 2] +
|
||||
+ values[aux8[3]] * yl[8*l + 4*i + 3];
|
||||
}
|
||||
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
|
||||
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
vl[0] >>= 2; vl[1] >>= 2;
|
||||
vh[0] >>= 1; vh[1] >>= 1;
|
||||
}
|
||||
|
||||
sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) * acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15));
|
||||
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
|
||||
acc[2] * (signs & 0x10 ? -s8[2] : s8[2]) + acc[3] * (signs & 0x40 ? -s8[3] : s8[3]));
|
||||
|
||||
}
|
||||
|
||||
@@ -6371,7 +6370,6 @@ void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 &
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256
|
||||
@@ -6379,19 +6377,6 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 &
|
||||
device const uint16_t * q16h = (device const uint16_t *)xb->qh + 8*(il&1);
|
||||
half d = xb->d * (2*((xb->scales_l[il/2] >> 4*(il&1)) & 0xf) + 1) * (xb->scales_h & (1 << il) ? -1 : 1);
|
||||
|
||||
//constant int8_t * int_values = iq3nl_values + 8*((xb->extra >> il) & 1);
|
||||
//half values[8] = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3],
|
||||
// d * int_values[4], d * int_values[5], d * int_values[6], d * int_values[7] };
|
||||
//const int shift = 2*((il%8)/2);
|
||||
//uint32_t aux32;
|
||||
//thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
|
||||
//for (int i = 0; i < 4; ++i) {
|
||||
// uint32_t vl = q16l[2*i+0] | (q16l[2*i+1] << 16);
|
||||
// uint32_t vh = q16h[2*i+0] | (q16h[2*i+1] << 16);
|
||||
// aux32 = ((vl >> shift) & 0x03030303) | (((vh >> ((il/2)%8)) << 2) & 0x04040404);
|
||||
// for (int j = 0; j < 4; ++j) reg[i][j] = values[aux8[j]];
|
||||
//}
|
||||
|
||||
constant int8_t * values = iq3nl_values + 8*((xb->extra >> il) & 1);
|
||||
|
||||
const int shift = 2*((il%8)/2);
|
||||
|
||||
Reference in New Issue
Block a user