iq2_ks: faster Metal kernels (mul_mv and dequantize)

LLaMA-3.1-8B:
PP-512 = 475.22 ± 0.37 t/s
TG-128 =  45.32 ± 0.03 t/s
This commit is contained in:
Iwan Kawrakow
2024-10-13 12:23:14 +03:00
parent 5cafaf5481
commit f9f15c27b6
2 changed files with 44 additions and 16 deletions

View File

@@ -2289,10 +2289,15 @@ static enum ggml_status ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ2_KS ||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_KS) {
const int mem_size = 64*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];

View File

@@ -3685,6 +3685,7 @@ constexpr constant static float kvalues_iq6k_f[128] = {
};
constexpr constant static float kvalues_iq2k_f[8] = { -31.f, -13.f, 1.f, 17.f, -26.f, -8.f, 6.f, 22.f };
constexpr constant static half kvalues_iq2k_h[8] = { -31.h, -13.h, 1.h, 17.h, -26.h, -8.h, 6.h, 22.h };
constexpr constant static float kvalues_iq3k_f[16] = { -63.f, -40.f, -23.f, -10.f, 1.f, 13.f, 28.f, 47.f, -59.f, -36.f, -19.f, -6.f, 5.f, 17.f, 32.f, 51.f };
constexpr constant static half kvalues_iq3k_h[16] = { -63.h, -40.h, -23.h, -10.h, 1.h, 13.h, 28.h, 47.h, -59.h, -36.h, -19.h, -6.h, 5.h, 17.h, 32.h, 51.h };
@@ -6304,6 +6305,24 @@ void kernel_mul_mv_iq2_ks_f32_impl(
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
threadgroup float * all_values = (threadgroup float *)shared_values + 32*sgitg;
{
//int row = tiisg%N_DST;
//device const half * dptr = (device const half *)(cx + row*row_size);
//const float d = *dptr;
//all_values[8*row + tiisg/N_DST] = d*iq2nl_values[tiisg/N_DST];
//threadgroup_barrier(mem_flags::mem_threadgroup);
int row = tiisg/8;
int pos = tiisg%8;
device const half * dptr = (device const half *)(cx + row*row_size);
const float d = *dptr;
all_values[8*row + pos] = d*kvalues_iq2k_f[pos];
simdgroup_barrier(mem_flags::mem_none);
//threadgroup_barrier(mem_flags::mem_threadgroup);
}
cx += sizeof(half);
uint32_t q32[2];
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
@@ -6317,15 +6336,14 @@ void kernel_mul_mv_iq2_ks_f32_impl(
yl[i+24] = y4[i+96];
}
device const block_iq2_ks * x = (device const block_iq2_ks *)cx + ib;
device const uint16_t * q16 = (device const uint16_t *)x->qs + 16*iq + 4*ir;
device const uint16_t * sc = (device const uint16_t *)x->scales;
device const uint16_t * ex = (device const uint16_t *)&x->extra;
for (int row = 0; row < N_DST; row++) {
device const half * dptr = (device const half *)(cx + row*row_size);
const float d = *dptr;
device const block_iq2_ks * x = (device const block_iq2_ks *)(dptr + 1);
device const block_iq2_ks & xb = x[ib];
device const uint16_t * q16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
device const uint16_t * sc = (device const uint16_t *)xb.scales;
threadgroup const float * row_values = all_values + 8*row;
uint32_t sc32 = (sc[iq] | (sc[iq] << 12)) & 0x0f0f0f0f;
thread const int8_t * s8 = (thread const int8_t *)&sc32;
@@ -6333,18 +6351,22 @@ void kernel_mul_mv_iq2_ks_f32_impl(
q32[0] = q16[0] | (q16[1] << 16);
q32[1] = q16[2] | (q16[3] << 16);
uint8_t extra = xb.extra << 4*(1-iq);
uint8_t extra = ex[0] << 4*(1-iq);
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
constant float * values = kvalues_iq2k_f + ((extra >> (2 + l)) & 4);
threadgroup const float * values = row_values + ((extra >> (2 + l)) & 4);
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
}
extra = xb.extra >> (8 + 4*iq);
sumf[row] += d * (acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16)));
extra = ex[0] >> (8 + 4*iq);
sumf[row] += acc[0] * (s8[0] - (extra & 1 ? 0 : 16)) + acc[1] * (s8[2] - (extra & 2 ? 0 : 16))
+ acc[2] * (s8[1] - (extra & 4 ? 0 : 16)) + acc[3] * (s8[3] - (extra & 8 ? 0 : 16));
q16 += row_size/2;
sc += row_size/2;
ex += row_size/2;
}
@@ -6381,11 +6403,12 @@ kernel void kernel_mul_mv_iq2_ks_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
kernel_mul_mv_iq2_ks_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq3_k_f32_impl(
@@ -7703,8 +7726,8 @@ void dequantize_iq2_ks(device const block_iq2_ks * xb, short il, thread type4x4
const short ib32 = il/2;
half d = (((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf) - ((xb->extra >> (8 + ib32)) & 1 ? 0 : 16));
constant int8_t * int_values = iq2nl_values + 4*((xb->extra >> ib32) & 1);
half4 values = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3] };
constant half4 * half_values = (constant half4 *)kvalues_iq2k_h;
half4 values = half_values[(xb->extra >> ib32) & 1] * d;
const int shift = 2*((il%8)/2);
thread uint16_t aux16[2];