mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-13 23:40:09 +00:00
iq2_ks: faster Metal
LLaMA-3.1-8B: PP-512 = 475.22 ± 0.37 t/s, TG-128 = 45.32 ± 0.03 t/s
This commit is contained in:
@@ -2289,10 +2289,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
|
||||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ2_KS ||
|
||||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
|
||||
src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_KS) {
|
||||
const int mem_size = 64*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
|
||||
@@ -3685,6 +3685,7 @@ constexpr constant static float kvalues_iq6k_f[128] = {
|
||||
};
|
||||
|
||||
// Float lookup values for the iq2k-family quants, stored as two consecutive
// groups of four. Each entry of the second group equals the matching entry of
// the first group plus 5. Code that consumes the table picks a group by
// offsetting the index by 0 or 4.
constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f };
|
||||
// Half-precision copy of the kvalues_iq2k_f values (same eight numbers, same order).
// dequantize_iq2_ks reinterprets this array as two half4 vectors and selects one
// of them with a single bit taken from the block's `extra` field.
constexpr constant static half kvalues_iq2k_h[8] = { -31.h, -13.h, 1.h, 17.h, -26.h, -8.h, 6.h, 22.h };
|
||||
|
||||
// Float lookup values for the iq3k quants, stored as two consecutive groups of
// eight. Each entry of the second group equals the matching entry of the first
// group plus 4.
// NOTE(review): the consumers of this table are not visible in this excerpt.
constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f };
|
||||
// Half-precision counterpart of kvalues_iq3k_f: the same sixteen values in the
// same order, typed as `half`.
constexpr constant static half kvalues_iq3k_h[16] = { -63.h, -40.h, -23.h, -10.h, 1.h, 13.h, 28.h, 47.h, -59.h, -36.h, -19.h, -6.h, 5.h, 17.h, 32.h, 51.h };
|
||||
@@ -6304,6 +6305,24 @@ void kernel_mul_mv_iq2_ks_f32_impl(
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
|
||||
|
||||
threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg;
|
||||
{
    // Build this simdgroup's table of pre-scaled lookup values in threadgroup
    // memory. Each lane writes exactly one entry: lane `tiisg` handles value
    // index tiisg%8 of row tiisg/8, so the row loop below can read
    // all_values[8*row + k] without multiplying by the row scale again.
    // NOTE(review): assumes the lanes of one simdgroup cover every row the loop
    // below reads (8 entries per row) - confirm against N_DST.
    int row = tiisg/8;
    int pos = tiisg%8;

    // The first half-precision value of each row is read as that row's scale `d`.
    device const half * dptr = (device const half *)(cx + row*row_size);
    const float d = *dptr;
    all_values[8*row + pos] = d*kvalues_iq2k_f[pos];

    // The entries written above are read by *other* lanes of the simdgroup, so
    // the barrier must also order threadgroup-memory accesses. mem_flags::mem_none
    // is an execution barrier only and applies no memory fence, which leaves the
    // cross-lane reads formally unsynchronised; mem_threadgroup makes the writes
    // visible before any lane proceeds.
    simdgroup_barrier(mem_flags::mem_threadgroup);
}
|
||||
|
||||
cx += sizeof(half);
|
||||
|
||||
uint32_t q32[2];
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||
@@ -6317,15 +6336,14 @@ void kernel_mul_mv_iq2_ks_f32_impl(
|
||||
yl[i+24] = y4[i+96];
|
||||
}
|
||||
|
||||
device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib;
|
||||
device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir;
|
||||
device const uint16_t * sc = (device const uint16_t *)x->scales;
|
||||
device const uint16_t * ex = (device const uint16_t *)&x->extra;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
device const half * dptr = (device const half *)(cx + row*row_size);
|
||||
const float d = *dptr;
|
||||
|
||||
device const block_iq2_ks * x = (device const block_iq2_ks *)(dptr + 1);
|
||||
device const block_iq2_ks & xb = x[ib];
|
||||
device const uint16_t * q16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
|
||||
device const uint16_t * sc = (device const uint16_t *)xb.scales;
|
||||
threadgroup const float * row_values = all_values + 8*row;
|
||||
|
||||
uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f;
|
||||
thread const int8_t * s8 = (thread const int8_t *)&sc32;
|
||||
@@ -6333,18 +6351,22 @@ void kernel_mul_mv_iq2_ks_f32_impl(
|
||||
q32[0] = q16[0] | (q16[1] << 16);
|
||||
q32[1] = q16[2] | (q16[3] << 16);
|
||||
|
||||
uint8_t extra = xb.extra << 4*(1-iq);
|
||||
uint8_t extra = ex[0] << 4*(1-iq);
|
||||
|
||||
float4 acc = {0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
constant float * values = kvalues_iq2k_f + ((extra >> (2 + l)) & 4);
|
||||
threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4);
|
||||
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
|
||||
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
|
||||
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
|
||||
}
|
||||
extra = xb.extra >> (8 + 4*iq);
|
||||
sumf[row] += d * (acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
|
||||
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16)));
|
||||
extra = ex[0] >> (8 + 4*iq);
|
||||
sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
|
||||
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16));
|
||||
|
||||
q16 += row_size/2;
|
||||
sc += row_size/2;
|
||||
ex += row_size/2;
|
||||
|
||||
}
|
||||
|
||||
@@ -6381,11 +6403,12 @@ kernel void kernel_mul_mv_iq2_ks_f32(
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq3_k_f32_impl(
|
||||
@@ -7703,8 +7726,8 @@ void dequantize_iq2_ks(device const block_iq2_ks * xb, short il, thread type4x4
|
||||
const short ib32 = il/2;
|
||||
half d = (((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf) - ((xb->extra >> (8 + ib32)) & 1 ? 0 : 16));
|
||||
|
||||
constant int8_t * int_values = iq2nl_values + 4*((xb->extra >> ib32) & 1);
|
||||
half4 values = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3] };
|
||||
constant half4 * half_values = (constant half4 *)kvalues_iq2k_h;
|
||||
half4 values = half_values[(xb->extra >> ib32) & 1] * d;
|
||||
|
||||
const int shift = 2*((il%8)/2);
|
||||
thread uint16_t aux16[2];
|
||||
|
||||
Reference in New Issue
Block a user