mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
iq1_bn(Metal): 686 -> 702 t/s for PP-512
This commit is contained in:
@@ -7476,22 +7476,23 @@ template <typename type4x4>
|
|||||||
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
|
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
|
||||||
// il is in 0...3
|
// il is in 0...3
|
||||||
|
|
||||||
constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
|
constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1};
|
||||||
|
constexpr half k_values[3] = {-1.h, 0.h, 1.h};
|
||||||
|
|
||||||
int i = 0;
|
|
||||||
for (int k = 0; k < 3; ++k) {
|
for (int k = 0; k < 3; ++k) {
|
||||||
uint8_t q = xb->ql[3*il + k];
|
uint16_t q = xb->ql[3*il + k];
|
||||||
for (int j = 0; j < 5; ++j) {
|
int i = 5*k + 4;
|
||||||
uint8_t v = k_mult[j]*q;
|
for (int j = 4; j >= 0; --j) {
|
||||||
int8_t vs = 3*v >> 8;
|
uint16_t v = q & 0xff;
|
||||||
//int8_t vs = (v + (v >> 1)) >> 7;
|
v += v << 1;
|
||||||
reg[i/4][i%4] = vs - 1;
|
reg[i/4][i%4] = k_values[v >> 8];
|
||||||
++i;
|
q += q << 1;
|
||||||
|
--i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
uint8_t v = k_mult[il]*xb->extra;
|
uint16_t v = (k_mult[il]*xb->extra) & 0xff;
|
||||||
int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
|
v += v << 1;
|
||||||
reg[3][3] = vs - 1;
|
reg[3][3] = k_values[v >> 8];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
|
|||||||
Reference in New Issue
Block a user