mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 16:30:12 +00:00
iq2_k: Metal dot product finally works
It is slow: 45.4 t/s for 7B model vs 50 t/s for iq2_xs, or 63.3 t/s for q2_K_S.
This commit is contained in:
@@ -5233,8 +5233,6 @@ void kernel_mul_mv_iq2_k_f32_impl(
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int step = (sizeof(block_q2_K) * nb) / 4;
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
@@ -5243,18 +5241,11 @@ void kernel_mul_mv_iq2_k_f32_impl(
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
//float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
//for (int i = 0; i < 8; ++i) {
|
||||
// yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
||||
// yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
|
||||
// yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
|
||||
// yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
|
||||
//}
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0];
|
||||
yl[i+ 8] = y4[i+32];
|
||||
@@ -5276,16 +5267,11 @@ void kernel_mul_mv_iq2_k_f32_impl(
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
constant float * values = kvalues_iq2k_f + 4*(extra & 1);
|
||||
extra >>= 2;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
aux32 = (q32[i] >> 2*l) & 0x03030303;
|
||||
acc[l] += values[aux8[0]] * yl[8*l + 4*i + 0] +
|
||||
+ values[aux8[1]] * yl[8*l + 4*i + 1] +
|
||||
+ values[aux8[2]] * yl[8*l + 4*i + 2] +
|
||||
+ values[aux8[3]] * yl[8*l + 4*i + 3];
|
||||
}
|
||||
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
|
||||
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
}
|
||||
|
||||
sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) * acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15));
|
||||
sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) + acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15));
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user