Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
iq3_k: slightly faster Metal dot product
@@ -2290,11 +2290,11 @@ static enum ggml_status ggml_metal_graph_compute(
                     src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                     src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
                     src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 ||
-                    src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
+                    src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
-                    const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
+                else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
+                    const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -2697,11 +2697,11 @@ static enum ggml_status ggml_metal_graph_compute(
                     src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                     src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
                     src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
-                    src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
+                    src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
-                    const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
+                else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
+                    const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -6449,8 +6449,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
     device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
+    threadgroup float * all_values = (threadgroup float *)shared_values + 16*sgitg;
+    {
+        if (tiisg < 16) all_values[tiisg] = kvalues_iq3k_f[tiisg];
+        simdgroup_barrier(mem_flags::mem_none);
+    }
+
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[N_DST]={0.f};
 
     const int ix = tiisg/8; // 0...3
     const int it = tiisg%8; // 0...7
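The hunk above is the heart of the change: instead of indexing the 16-entry constant table kvalues_iq3k_f inside the hot loop, each simdgroup copies the table once into its own 16-float slice of threadgroup memory and reads from that copy afterwards. This is presumably why the host side now reserves 32*sizeof(float) for IQ3_K (two simdgroups per threadgroup, 16 floats each), though the simdgroup count is an inference from the diff, not stated in it. Below is a minimal, self-contained sketch of the same pattern; the kernel name, buffer layout, and table values are illustrative placeholders, not the actual ik_llama.cpp code.

#include <metal_stdlib>
using namespace metal;

// Placeholder 16-entry table standing in for kvalues_iq3k_f.
constant float demo_lut_f[16] = {
    -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f,
     1.f,  2.f,  3.f,  4.f,  5.f,  6.f,  7.f,  8.f
};

kernel void lut_copy_demo(
        device const uchar * idx   [[buffer(0)]],
        device       float * out   [[buffer(1)]],
        threadgroup  float * shmem [[threadgroup(0)]],   // host reserves 16 floats per simdgroup
        uint tid   [[thread_position_in_grid]],
        uint tiisg [[thread_index_in_simdgroup]],
        uint sgitg [[simdgroup_index_in_threadgroup]]) {
    // Each simdgroup gets its own slice, mirroring shared_values + 16*sgitg above.
    threadgroup float * lut = shmem + 16*sgitg;
    if (tiisg < 16) lut[tiisg] = demo_lut_f[tiisg];       // one cooperative copy per simdgroup
    simdgroup_barrier(mem_flags::mem_threadgroup);        // conservative fence before the reads
    // The hot path now indexes threadgroup memory instead of constant memory.
    out[tid] = lut[idx[tid] & 15];
}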
@@ -6493,7 +6499,8 @@ void kernel_mul_mv_iq3_k_f32_impl(
 
             float4 acc = {0.f};
             for (int l = 0; l < 4; ++l) {
-                constant float * values = kvalues_iq3k_f + (extra & 8);
+                threadgroup const float * values = all_values + (extra & 8);
+                //constant float * values = kvalues_iq3k_f + (extra & 8);
                 aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
                 aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
                 for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
@@ -6510,10 +6517,11 @@ void kernel_mul_mv_iq3_k_f32_impl(
         y4 += 4 * QK_K;
     }
 
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+    for (int row = 0; row < N_DST; row += 2) {
+        float2 tmp{sumf[row], sumf[row+1]};
+        tmp = simd_sum(tmp);
+        if (tiisg < 2) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg];
         }
     }
 }
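The final-reduction rewrite above folds the sums of two rows into a single float2 simd_sum, so N_DST rows need half as many cross-lane reductions, and lanes 0 and 1 each store one of the two reduced values; this pairing presumably relies on N_DST being even. A small standalone sketch of the same pattern follows, with hypothetical names and arbitrary demo partial sums rather than the actual kernel state.

#include <metal_stdlib>
using namespace metal;

#define DEMO_N_DST 4   // stands in for N_DST; must be even for the pairing below

kernel void paired_reduce_demo(
        device float * dst [[buffer(0)]],
        uint tiisg [[thread_index_in_simdgroup]]) {
    // Per-lane partial sums, one per output row (arbitrary demo values).
    float sumf[DEMO_N_DST];
    for (int row = 0; row < DEMO_N_DST; ++row) {
        sumf[row] = 0.25f * float(row + 1) * float(tiisg + 1);
    }

    // Reduce two rows per simd_sum; lanes 0 and 1 each write one reduced row.
    for (int row = 0; row < DEMO_N_DST; row += 2) {
        float2 tmp{sumf[row], sumf[row+1]};
        tmp = simd_sum(tmp);
        if (tiisg < 2) {
            dst[row + tiisg] = tmp[tiisg];
        }
    }
}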
@@ -6539,11 +6547,12 @@ kernel void kernel_mul_mv_iq3_k_f32(
         constant int64_t & ne1,
         constant uint & r2,
         constant uint & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq4_k_f32_impl(
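The wrapper hunk above simply threads the allocation through: the kernel entry point now declares a [[threadgroup(0)]] argument and forwards it to the _impl function instead of passing nullptr, and the bytes it receives are the ones the host reserves with setThreadgroupMemoryLength:mem_size atIndex:0 in the first two hunks. A stripped-down sketch of that wiring, using illustrative names rather than the real signatures:

#include <metal_stdlib>
using namespace metal;

// The impl function takes the threadgroup pointer as an ordinary argument ...
void demo_mul_mv_impl(
        device const float * src0,
        device       float * dst,
        threadgroup  int8_t * shared_values,
        uint tiisg) {
    threadgroup float * values = (threadgroup float *)shared_values;
    if (tiisg < 16) values[tiisg] = 1.0f;            // stand-in for the LUT copy
    simdgroup_barrier(mem_flags::mem_threadgroup);
    dst[tiisg] = src0[tiisg] * values[tiisg & 15];
}

// ... while the kernel entry point owns the [[threadgroup(0)]] binding, whose
// size must match what the host passes to setThreadgroupMemoryLength:atIndex:0.
kernel void demo_mul_mv(
        device const float * src0           [[buffer(0)]],
        device       float * dst            [[buffer(1)]],
        threadgroup  int8_t * shared_values [[threadgroup(0)]],
        uint tiisg [[thread_index_in_simdgroup]]) {
    demo_mul_mv_impl(src0, dst, shared_values, tiisg);
}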