mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
bitnet(scale in a separate tensor): Metal
iq2_bn TG-128 drops to 84 t/s, while I see in the logs that we had 97 t/s. If true, that's a pretty massive performance penalty for TG. Let me guess: ggml_mul is not exactly the most performant operation on Metal.
This commit is contained in:
@@ -5033,7 +5033,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
|
|
||||||
float yl[8];
|
float yl[8];
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
float d1bn[N_DST];
|
|
||||||
|
|
||||||
const int nb32 = nb * (QK_IQ1BN / 32);
|
const int nb32 = nb * (QK_IQ1BN / 32);
|
||||||
|
|
||||||
@@ -5042,10 +5041,6 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
|
|
||||||
device const float * y4 = y + 32 * ix + 8 * ir;
|
device const float * y4 = y + 32 * ix + 8 * ir;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
|
||||||
d1bn[row] = iq1bn_fp8_to_float(x[nb*row].extra & 0xff);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t aux32[2];
|
uint32_t aux32[2];
|
||||||
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
|
||||||
|
|
||||||
@@ -5060,13 +5055,13 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
const int ib = ib32 % (QK_IQ1BN / 32);
|
const int ib = ib32 % (QK_IQ1BN / 32);
|
||||||
|
|
||||||
device const block_iq1_bn * xr = x + ibl;
|
device const block_iq1_bn * xr = x + ibl;
|
||||||
device const uint16_t * extra = (device const uint16_t *)&xr->extra;
|
device const uint8_t * extra = (device const uint8_t *)&xr->extra;
|
||||||
device const uint8_t * ql = xr->ql + 4 * ib + ir;
|
device const uint8_t * ql = xr->ql + 4 * ib + ir;
|
||||||
device const uint8_t * qh = xr->qh + 2 * ib + ir/2;
|
device const uint8_t * qh = xr->qh + 2 * ib + ir/2;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
|
||||||
uint8_t signs = extra[0] >> (8 + 4*ib + ir);
|
uint8_t signs = extra[0] >> (4*ib + ir);
|
||||||
float acc = 0.f;
|
float acc = 0.f;
|
||||||
|
|
||||||
uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
|
uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
|
||||||
@@ -5077,7 +5072,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
|
|
||||||
sumf[row] += (signs & 1 ? sumy-acc : acc-sumy);
|
sumf[row] += (signs & 1 ? sumy-acc : acc-sumy);
|
||||||
|
|
||||||
extra += nb*sizeof(block_iq1_bn)/2;
|
extra += nb*sizeof(block_iq1_bn);
|
||||||
ql += nb*sizeof(block_iq1_bn);
|
ql += nb*sizeof(block_iq1_bn);
|
||||||
qh += nb*sizeof(block_iq1_bn);
|
qh += nb*sizeof(block_iq1_bn);
|
||||||
}
|
}
|
||||||
@@ -5089,7 +5084,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
|
|||||||
half2 r = {(half)sumf[row], (half)sumf[row+1]};
|
half2 r = {(half)sumf[row], (half)sumf[row+1]};
|
||||||
r = simd_sum(r);
|
r = simd_sum(r);
|
||||||
if (tiisg < 2) {
|
if (tiisg < 2) {
|
||||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg] * d1bn[row + tiisg];
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -5130,7 +5125,6 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
|||||||
|
|
||||||
float yl[16];
|
float yl[16];
|
||||||
float sumf[N_DST]={0.f};
|
float sumf[N_DST]={0.f};
|
||||||
float d1bn[N_DST];
|
|
||||||
|
|
||||||
const int nb32 = nb * (QK_IQ1BN / 32);
|
const int nb32 = nb * (QK_IQ1BN / 32);
|
||||||
|
|
||||||
@@ -5139,10 +5133,6 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
|||||||
|
|
||||||
device const float * y4 = y + 64 * ix + 4 * ir;
|
device const float * y4 = y + 64 * ix + 4 * ir;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
|
||||||
d1bn[row] = x[nb*row].d;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int ib = ix; ib < nb; ib += 8) {
|
for (int ib = ix; ib < nb; ib += 8) {
|
||||||
|
|
||||||
float sumy = 0.f;
|
float sumy = 0.f;
|
||||||
@@ -5177,7 +5167,7 @@ void kernel_mul_mv_iq2_bn_f32_impl(
|
|||||||
half2 r = {(half)sumf[row], (half)sumf[row+1]};
|
half2 r = {(half)sumf[row], (half)sumf[row+1]};
|
||||||
r = simd_sum(r);
|
r = simd_sum(r);
|
||||||
if (tiisg < 2) {
|
if (tiisg < 2) {
|
||||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg] * d1bn[row + tiisg];
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -5943,10 +5933,9 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
|
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
|
||||||
// il is in 0...3
|
// il is in 0...3
|
||||||
const float d = iq1bn_fp8_to_float(xb->extra & 0xff);
|
uint8_t gs = xb->extra >> 2*il;
|
||||||
uint8_t gs = xb->extra >> (8 + 2*il);
|
const float d1 = gs & 1 ? -1 : 1;
|
||||||
const float d1 = gs & 1 ? -d : d;
|
const float d2 = gs & 2 ? -1 : 1;
|
||||||
const float d2 = gs & 2 ? -d : d;
|
|
||||||
|
|
||||||
uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
|
uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
|
||||||
uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
|
uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
|
||||||
@@ -5969,13 +5958,11 @@ void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4
|
|||||||
// il is in 0...3
|
// il is in 0...3
|
||||||
constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f};
|
constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f};
|
||||||
constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0};
|
constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0};
|
||||||
float d = xb->d;
|
const float d = k_scale[il];
|
||||||
const float m = -d;
|
|
||||||
d *= k_scale[il];
|
|
||||||
uint8_t mask = k_mask[il];
|
uint8_t mask = k_mask[il];
|
||||||
|
|
||||||
for (int j = 0; j < 16; ++j) {
|
for (int j = 0; j < 16; ++j) {
|
||||||
reg[j/4][j%4] = d * (xb->qs[j] & mask) + m;
|
reg[j/4][j%4] = d * (xb->qs[j] & mask) - 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user