mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
New iq4_kt trellis: not working Metal implementation
This commit is contained in:
@@ -6596,6 +6596,25 @@ void kernel_mul_mv_iq2_k_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
struct Trellis3 {
|
||||
constexpr constant static uint32_t kmask = 0x3f3f3f3f;
|
||||
constexpr constant static uint32_t ka = 89226354;
|
||||
constexpr constant static uint32_t kb = 64248484;
|
||||
constexpr constant static uint32_t ka1 = ka*ka;
|
||||
constexpr constant static uint32_t kb1 = kb*ka+kb;
|
||||
constexpr constant static uint32_t ka2 = ka1*ka;
|
||||
constexpr constant static uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr constant static uint32_t ka3 = ka2*ka;
|
||||
constexpr constant static uint32_t kb3 = kb2*ka+kb;
|
||||
static inline char4 gen4(uint32_t val) {
|
||||
thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask};
|
||||
thread const int8_t * a8 = (thread const int8_t *)aux;
|
||||
char4 result;
|
||||
for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct Trellis {
|
||||
constexpr constant static uint32_t kmask1 = 0x8fff8fff;
|
||||
constexpr constant static uint32_t kmask2 = 0x3b603b60;
|
||||
@@ -8586,20 +8605,20 @@ void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread
|
||||
device const uint32_t * shb = x->qs;
|
||||
device const uint8_t * ql = (device const uint8_t *)(shb + 8);
|
||||
device const uint8_t * qh = ql + 64;
|
||||
float scale = d * (((shb[ib32] & 0xff) >> 1) - 64);
|
||||
const int ls = (shb[ib32] & 0xff) >> 1;
|
||||
const float scale = d * (ls - 64);
|
||||
const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
|
||||
|
||||
const int jj = ib32*8 + 4*(il%2);
|
||||
ql += jj;
|
||||
qh += jj%32;
|
||||
ql += 8*ib32;
|
||||
qh += 8*(ib32%4);
|
||||
|
||||
uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12;
|
||||
const int shift = 8 - 4*(jj/32);
|
||||
const int shift = 8 - 4*(ib32/4);
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
|
||||
auto v = (float4)Trellis::gen4(idx);
|
||||
reg[i] = v * scale;
|
||||
auto c4 = Trellis3::gen4(idx);
|
||||
reg[i] = {scale*c4[0], scale*c4[1], scale*c4[2], scale*c4[3]};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8931,18 +8950,17 @@ struct DequantizerKT4 {
|
||||
using type4x4 = T4x4;
|
||||
DequantizerKT4(device const char * cx, short il = 0) : il(il) {
|
||||
device const float * dptr = (device const float *)cx;
|
||||
d[0] = dptr[0] * 31.75f * 1.01f;
|
||||
d[1] = dptr[1];
|
||||
d = dptr[0] * 1.01f;
|
||||
x = (device const Block *)(dptr + 2);
|
||||
}
|
||||
inline void convert(thread T4x4& t) const {
|
||||
float4x4 tmp;
|
||||
dequantize_iq4_kt(x, il, d[0], tmp);
|
||||
dequantize_iq4_kt(x, il, d, tmp);
|
||||
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
|
||||
}
|
||||
inline void convert(int64_t ind, thread T4x4& t) {
|
||||
float4x4 tmp;
|
||||
dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp);
|
||||
dequantize_iq4_kt(x + ind/nl, ind%nl, d, tmp);
|
||||
for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
|
||||
}
|
||||
inline void next() {
|
||||
@@ -8951,7 +8969,7 @@ struct DequantizerKT4 {
|
||||
}
|
||||
device const Block * x;
|
||||
short il;
|
||||
float d[2];
|
||||
float d;
|
||||
};
|
||||
|
||||
template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
|
||||
|
||||
Reference in New Issue
Block a user