mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
iq2_kt: Metal GEMV
Performance is actually quite decent: 52 t/s on my M2-Max for LlaMA-3.1-8B
This commit is contained in:
@@ -6596,7 +6596,55 @@ void kernel_mul_mv_iq2_k_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
struct Trellis {
|
||||
constexpr constant static uint32_t kmask1 = 0x8fff8fff;
|
||||
constexpr constant static uint32_t kmask2 = 0x3b603b60;
|
||||
constexpr constant static uint32_t ka = 89226354;
|
||||
constexpr constant static uint32_t kb = 64248484;
|
||||
constexpr constant static uint32_t ka1 = ka*ka;
|
||||
constexpr constant static uint32_t kb1 = kb*ka+kb;
|
||||
constexpr constant static uint32_t ka2 = ka1*ka;
|
||||
constexpr constant static uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr constant static uint32_t ka3 = ka2*ka;
|
||||
constexpr constant static uint32_t kb3 = kb2*ka+kb;
|
||||
constexpr constant static uint32_t ka4 = ka3*ka;
|
||||
constexpr constant static uint32_t kb4 = kb3*ka+kb;
|
||||
constexpr constant static uint32_t ka5 = ka4*ka;
|
||||
constexpr constant static uint32_t kb5 = kb4*ka+kb;
|
||||
constexpr constant static uint32_t ka6 = ka5*ka;
|
||||
constexpr constant static uint32_t kb6 = kb5*ka+kb;
|
||||
constexpr constant static uint32_t ka7 = ka6*ka;
|
||||
constexpr constant static uint32_t kb7 = kb6*ka+kb;
|
||||
|
||||
static inline half4 gen4(uint32_t val) {
|
||||
thread uint32_t aux[4] = {((ka *val + kb ) & kmask1) ^ kmask2,
|
||||
((ka1*val + kb1) & kmask1) ^ kmask2,
|
||||
((ka2*val + kb2) & kmask1) ^ kmask2,
|
||||
((ka3*val + kb3) & kmask1) ^ kmask2};
|
||||
const thread half * h = (const thread half *)aux;
|
||||
return { h[0]+h[1], h[2]+h[3], h[4]+h[5], h[6]+h[7] };
|
||||
}
|
||||
template <typename T4>
|
||||
static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
|
||||
thread uint32_t aux[8] = {((ka *val + kb ) & kmask1) ^ kmask2,
|
||||
((ka1*val + kb1) & kmask1) ^ kmask2,
|
||||
((ka2*val + kb2) & kmask1) ^ kmask2,
|
||||
((ka3*val + kb3) & kmask1) ^ kmask2,
|
||||
((ka4*val + kb4) & kmask1) ^ kmask2,
|
||||
((ka5*val + kb5) & kmask1) ^ kmask2,
|
||||
((ka6*val + kb6) & kmask1) ^ kmask2,
|
||||
((ka7*val + kb7) & kmask1) ^ kmask2};
|
||||
const thread half * h = (const thread half *)aux;
|
||||
if constexpr (is_same_v<T4, half4>) {
|
||||
v1 = { h[0]+h[1], h[2]+h[3], h[4]+h[5], h[6]+h[7] };
|
||||
v2 = { h[8]+h[9], h[10]+h[11], h[12]+h[13], h[14]+h[15] };
|
||||
} else {
|
||||
v1 = { (float)(h[0]+h[1]), (float)(h[ 2]+h[ 3]), (float)(h[ 4]+h[ 5]), (float)(h[ 6]+h[ 7]) };
|
||||
v2 = { (float)(h[8]+h[9]), (float)(h[10]+h[11]), (float)(h[12]+h[13]), (float)(h[14]+h[15]) };
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void kernel_mul_mv_iq2_kt_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@@ -6621,7 +6669,7 @@ void kernel_mul_mv_iq2_kt_f32_impl(
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const uint row_size = 2 + nb*sizeof(block_iq2_ks);
|
||||
const uint row_size = sizeof(float) + nb*sizeof(block_iq2_kt);
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
@@ -6631,91 +6679,55 @@ void kernel_mul_mv_iq2_kt_f32_impl(
|
||||
device const char * cx = (device const char *) src0 + (first_row + offset0)*row_size;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f};
|
||||
float4 sumf={0.f};
|
||||
|
||||
const int ix = tiisg/8; // 0...3
|
||||
const int it = tiisg%8; // 0...7
|
||||
const int iq = it/4; // 0 or 1
|
||||
const int ir = it%4; // 0...3
|
||||
const int ix = tiisg/16; // 0...1
|
||||
const int it = tiisg%16; // 0...15
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg;
|
||||
{
|
||||
//int row = tiisg%N_DST;
|
||||
//device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
//const float d = *dptr;
|
||||
//all_values[8*row + tiisg/N_DST] = d*iq2nl_values[tiisg/N_DST];
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
int row = tiisg/8;
|
||||
int pos = tiisg%8;
|
||||
device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
const float d = *dptr;
|
||||
all_values[8*row + pos] = d*kvalues_iq2k_f[pos];
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float4 v1, v2;
|
||||
|
||||
float drow[N_DST];
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
device const float * dptr = (device const float *)(cx + row*row_size);
|
||||
drow[row] = dptr[0] * 31.75f * 1.05f;
|
||||
}
|
||||
|
||||
cx += sizeof(half);
|
||||
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
|
||||
|
||||
uint32_t q32[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
for (int ib = ix; ib < nb; ib += 2) {
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0];
|
||||
yl[i+ 8] = y4[i+32];
|
||||
yl[i+16] = y4[i+64];
|
||||
yl[i+24] = y4[i+96];
|
||||
}
|
||||
|
||||
device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib;
|
||||
device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir;
|
||||
device const uint16_t * sc = (device const uint16_t *)x->scales;
|
||||
device const uint16_t * ex = (device const uint16_t *)&x->extra;
|
||||
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
threadgroup const float * row_values = all_values + 8*row;
|
||||
device const uint16_t * q2 = (device const uint16_t *)(sc + 4);
|
||||
|
||||
uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&sc32;
|
||||
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
|
||||
|
||||
q32[0] = q16[0] | (q16[1] << 16);
|
||||
q32[1] = q16[2] | (q16[3] << 16);
|
||||
Trellis::gen8(q2[2*it+0]+4096, v1, v2);
|
||||
auto sum = v1*y4[0] + v2*y4[1];
|
||||
|
||||
uint8_t extra = ex[0] << 4*(1-iq);
|
||||
Trellis::gen8(q2[2*it+1]+4096, v1, v2);
|
||||
sum += v1*y4[2] + v2*y4[3];
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4);
|
||||
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
|
||||
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
}
|
||||
extra = ex[0] >> (8 + 4*iq);
|
||||
sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
|
||||
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16));
|
||||
sum *= ls;
|
||||
|
||||
q16 += row_size/2;
|
||||
sc += row_size/2;
|
||||
ex += row_size/2;
|
||||
sumf[row] += sum[0] + sum[1] + sum[2] + sum[3];
|
||||
|
||||
sc += row_size;
|
||||
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
y4 += QK_K/2;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row += 2) {
|
||||
float2 tmp = {sumf[row], sumf[row+1]};
|
||||
tmp = simd_sum(tmp);
|
||||
if (tiisg < 2) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg];
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
if (tiisg < 4) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq2_kt_f32")]]
|
||||
@@ -8264,49 +8276,6 @@ void dequantize_iq4_kss(device const block_iq4_kss * xb, short il, thread type4x
|
||||
}
|
||||
}
|
||||
|
||||
struct Trellis {
|
||||
constexpr constant static uint32_t kmask1 = 0x8fff8fff;
|
||||
constexpr constant static uint32_t kmask2 = 0x3b603b60;
|
||||
constexpr constant static uint32_t ka = 89226354;
|
||||
constexpr constant static uint32_t kb = 64248484;
|
||||
constexpr constant static uint32_t ka1 = ka*ka;
|
||||
constexpr constant static uint32_t kb1 = kb*ka+kb;
|
||||
constexpr constant static uint32_t ka2 = ka1*ka;
|
||||
constexpr constant static uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr constant static uint32_t ka3 = ka2*ka;
|
||||
constexpr constant static uint32_t kb3 = kb2*ka+kb;
|
||||
constexpr constant static uint32_t ka4 = ka3*ka;
|
||||
constexpr constant static uint32_t kb4 = kb3*ka+kb;
|
||||
constexpr constant static uint32_t ka5 = ka4*ka;
|
||||
constexpr constant static uint32_t kb5 = kb4*ka+kb;
|
||||
constexpr constant static uint32_t ka6 = ka5*ka;
|
||||
constexpr constant static uint32_t kb6 = kb5*ka+kb;
|
||||
constexpr constant static uint32_t ka7 = ka6*ka;
|
||||
constexpr constant static uint32_t kb7 = kb6*ka+kb;
|
||||
|
||||
static inline half4 gen4(uint32_t val) {
|
||||
thread uint32_t aux[4] = {((ka *val + kb ) & kmask1) ^ kmask2,
|
||||
((ka1*val + kb1) & kmask1) ^ kmask2,
|
||||
((ka2*val + kb2) & kmask1) ^ kmask2,
|
||||
((ka3*val + kb3) & kmask1) ^ kmask2};
|
||||
const thread half * h = (const thread half *)aux;
|
||||
return { h[0]+h[1], h[2]+h[3], h[4]+h[5], h[6]+h[7] };
|
||||
}
|
||||
static inline void gen8(uint32_t val, thread half4& v1, thread half4& v2) {
|
||||
thread uint32_t aux[8] = {((ka *val + kb ) & kmask1) ^ kmask2,
|
||||
((ka1*val + kb1) & kmask1) ^ kmask2,
|
||||
((ka2*val + kb2) & kmask1) ^ kmask2,
|
||||
((ka3*val + kb3) & kmask1) ^ kmask2,
|
||||
((ka4*val + kb4) & kmask1) ^ kmask2,
|
||||
((ka5*val + kb5) & kmask1) ^ kmask2,
|
||||
((ka6*val + kb6) & kmask1) ^ kmask2,
|
||||
((ka7*val + kb7) & kmask1) ^ kmask2};
|
||||
const thread half * h = (const thread half *)aux;
|
||||
v1 = { h[0]+h[1], h[2]+h[3], h[4]+h[5], h[6]+h[7] };
|
||||
v2 = { h[8]+h[9], h[10]+h[11], h[12]+h[13], h[14]+h[15] };
|
||||
}
|
||||
};
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256
|
||||
|
||||
Reference in New Issue
Block a user