mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Faster iq4_k: Metal
This commit is contained in:
@@ -6492,73 +6492,59 @@ void kernel_mul_mv_iq4_k_f32_impl(
|
||||
|
||||
const int ix = tiisg/16; // 0 or 1
|
||||
const int it = tiisg%16; // 0...15
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4k_f[tiisg];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[2]={0.f}, all_sum;
|
||||
float2 sumf = 0.f;
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
device const float * yb = y + ix * QK_K + it * 4;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
float4 qf;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
//float2 sumy;
|
||||
//sumy[0] = -4.f*(yl[0][0] + yl[0][1] + yl[0][2] + yl[0][3] + yl[2][0] + yl[2][1] + yl[2][2] + yl[2][3]);
|
||||
//sumy[1] = -4.f*(yl[1][0] + yl[1][1] + yl[1][2] + yl[1][3] + yl[3][0] + yl[3][1] + yl[3][2] + yl[3][3]);
|
||||
yl[0] = y4[0]; yl[1] = y4[16]; yl[2] = y4[32]; yl[3] = y4[48];
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
device const block_iq4_k & xb = x[row*nb + ibl];
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 4*it);
|
||||
|
||||
uint16_t extra = xb.extra >> 2*ib;
|
||||
threadgroup const float * values1 = shared_values + 16*(extra & 1);
|
||||
threadgroup const float * values2 = shared_values + 8*(extra & 2);
|
||||
uint16_t extra = xb.extra >> it;
|
||||
threadgroup const float * values = shared_values + ((extra & 1) << 4);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
aux32 = q4[0] & 0x0f0f0f0f;
|
||||
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
|
||||
float4 acc = yl[0] * qf;
|
||||
aux32 = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
|
||||
acc += yl[1] * qf;
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
|
||||
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
aux32 = q4[16] & 0x0f0f0f0f;
|
||||
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
|
||||
acc += yl[2] * qf;
|
||||
aux32 = (q4[16] >> 4) & 0x0f0f0f0f;
|
||||
qf = {values[q8[0]], values[q8[1]], values[q8[2]], values[q8[3]]};
|
||||
acc += yl[3] * qf;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {values1[q8[0]], values1[q8[1]], values1[q8[2]], values1[q8[3]]};
|
||||
qf2 = {values2[q8[4]], values2[q8[5]], values2[q8[6]], values2[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
const uint8_t h = xb.scales_h[ib/2] >> 4*(ib%2);
|
||||
const int ls1 = ((xb.scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
|
||||
const int ls2 = ((xb.scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
|
||||
sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3]) + ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3]));
|
||||
//uint16_t extra = xb.extra >> 2*ib;
|
||||
//sumf[row] += (float)xb.d * (ls1 * (acc1[0] + acc1[1] + acc1[2] + acc1[3] + (extra & 1 ? sumy[0] : 0)) +
|
||||
// ls2 * (acc2[0] + acc2[1] + acc2[2] + acc2[3] + (extra & 2 ? sumy[1] : 0)));
|
||||
const uint8_t h = xb.scales_h[it/4] >> 2*(it%4);
|
||||
const int ls = (((xb.scales_l[it/2] >> 4*(it%2)) & 0xf) | ((h << 4) & 0x30)) - 32;
|
||||
sumf[row] += (float)xb.d * ls * (acc[0] + acc[1] + acc[2] + acc[3]);
|
||||
|
||||
}
|
||||
|
||||
yb += 2 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
if (tiisg < 2) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7692,22 +7678,26 @@ void dequantize_iq3_k(device const block_iq3_k * xb, short il, thread type4x4 &
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_k(device const block_iq4_k * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
const int l = il%2;
|
||||
// l = 0 or 1. l = 0 processes the first 16 quants in a block of 32, l = 1 the second 16
|
||||
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
|
||||
const int ls = ((xb->scales_l[ib32] >> 4*l) & 0xf) | (((xb->scales_h[il/4] >> 2*(il%4)) & 3) << 4);
|
||||
const float d = (float)xb->d * (ls - 32);
|
||||
int i16 = 4*(il%4);
|
||||
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 16*(il/8) + i16;
|
||||
const int shift = 4*((il/4)%2);
|
||||
uint16_t extra = xb->extra >> i16;
|
||||
const float d = xb->d;
|
||||
float4 dl;
|
||||
dl[0] = d * (((xb->scales_l[2*(il%4)+0] & 0xf) | ((xb->scales_h[il%4] << 4) & 0x30)) - 32);
|
||||
dl[1] = d * (((xb->scales_l[2*(il%4)+0] >> 4) | ((xb->scales_h[il%4] << 2) & 0x30)) - 32);
|
||||
dl[2] = d * (((xb->scales_l[2*(il%4)+1] & 0xf) | ((xb->scales_h[il%4] >> 0) & 0x30)) - 32);
|
||||
dl[3] = d * (((xb->scales_l[2*(il%4)+1] >> 4) | ((xb->scales_h[il%4] >> 2) & 0x30)) - 32);
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
constant float * values = kvalues_iq4k_f + 16*((xb->extra >> il) & 1);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = (q4[i] >> 4*l) & 0x0f0f0f0f;
|
||||
reg[i][0] = d * values[q8[0]];
|
||||
reg[i][1] = d * values[q8[1]];
|
||||
reg[i][2] = d * values[q8[2]];
|
||||
reg[i][3] = d * values[q8[3]];
|
||||
constant float * values = kvalues_iq4k_f + ((extra & 1) << 4);
|
||||
aux32 = (q4[i] >> shift) & 0x0f0f0f0f;
|
||||
reg[i][0] = dl[i] * values[q8[0]];
|
||||
reg[i][1] = dl[i] * values[q8[1]];
|
||||
reg[i][2] = dl[i] * values[q8[2]];
|
||||
reg[i][3] = dl[i] * values[q8[3]];
|
||||
extra >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user