mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
New iq2_kt: Metal - very slow.
It seems Apple Silicon cannot quickly add 4 8-bit ints. Or I don't know how to do it - but I didn't find anything in the Metal Shading Language Specification. So, performance is quite a bit worse than the original trellis.
This commit is contained in:
@@ -6613,6 +6613,18 @@ struct Trellis3 {
|
||||
for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
|
||||
return result;
|
||||
}
|
||||
template <typename T4>
|
||||
static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
|
||||
thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
|
||||
uint32_t aux32[2];
|
||||
thread const int8_t * a8 = (thread const int8_t *)aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32[0] = aux[i] & kmask;
|
||||
aux32[1] = (ka3*aux[i] + kb3) & kmask;
|
||||
v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3];
|
||||
v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Trellis {
|
||||
@@ -6710,7 +6722,7 @@ void kernel_mul_mv_iq2_kt_f32_impl(
|
||||
float drow[N_DST];
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
device const float * dptr = (device const float *)(cx + row*row_size);
|
||||
drow[row] = dptr[0] * 31.75f * 1.05f;
|
||||
drow[row] = dptr[0] * 1.05f;
|
||||
}
|
||||
|
||||
device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
|
||||
@@ -6725,10 +6737,10 @@ void kernel_mul_mv_iq2_kt_f32_impl(
|
||||
|
||||
const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
|
||||
|
||||
Trellis::gen8(q2[2*it+0]+4096, v1, v2);
|
||||
Trellis3::gen8(q2[2*it+0]+4096, v1, v2);
|
||||
auto sum = v1*y4[0] + v2*y4[1];
|
||||
|
||||
Trellis::gen8(q2[2*it+1]+4096, v1, v2);
|
||||
Trellis3::gen8(q2[2*it+1]+4096, v1, v2);
|
||||
sum += v1*y4[2] + v2*y4[3];
|
||||
|
||||
sum *= ls;
|
||||
@@ -8561,19 +8573,18 @@ template <typename type4x4>
|
||||
void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256
|
||||
int ib32 = il/2;
|
||||
half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h;
|
||||
half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 1.05h;
|
||||
device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
|
||||
|
||||
half4 v1, v2;
|
||||
char4 v1, v2;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
Trellis::gen8(q2[i]+4096, v1, v2);
|
||||
v1 *= scale; v2 *= scale;
|
||||
Trellis3::gen8(q2[i]+4096, v1, v2);
|
||||
if constexpr (is_same_v<type4x4, half4x4>) {
|
||||
reg[2*i+0] = v1;
|
||||
reg[2*i+1] = v2;
|
||||
reg[2*i+0] = {scale*(half)v1[0], scale*(half)v1[1], scale*(half)v1[2], scale*(half)v1[3]};
|
||||
reg[2*i+1] = {scale*(half)v2[0], scale*(half)v2[1], scale*(half)v2[2], scale*(half)v2[3]};
|
||||
} else {
|
||||
reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]};
|
||||
reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]};
|
||||
reg[2*i+0] = {scale*(float)v1[0], scale*(float)v1[1], scale*(float)v1[2], scale*(float)v1[3]};
|
||||
reg[2*i+1] = {scale*(float)v2[0], scale*(float)v2[1], scale*(float)v2[2], scale*(float)v2[3]};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user