mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
MoE improvements on Metal
This version beats mainline, there are things I don't understand: * Mianline has effectively gone to GEMV for MUL_MAT_ID. We can do the same, but we are 30% slower. Why? * Using actual GEMM, we beat mainline with ubtach size of 128. But then performance degrades. Why?
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
yl[i] = yb[i];
|
||||
}
|
||||
|
||||
device const block_q8_0 * xr = x + ib;
|
||||
|
||||
for (int row = 0; row < nr; row++) {
|
||||
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||
//device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||
device const int8_t * qs = xr->qs + NB_Q8_0*il;
|
||||
float sumq = 0.f;
|
||||
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||
sumq += qs[iq] * yl[iq];
|
||||
}
|
||||
sumf[row] += sumq*x[ib+row*nb].d;
|
||||
//sumf[row] += sumq*x[ib+row*nb].d;
|
||||
sumf[row] += sumq*xr->d;
|
||||
xr += nb;
|
||||
}
|
||||
|
||||
yb += NB_Q8_0 * nw;
|
||||
@@ -5746,7 +5751,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
|
||||
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
||||
//threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
@@ -5766,8 +5771,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
const int ib = it/2;
|
||||
const int il = it%2;
|
||||
|
||||
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[2]={0.f}, all_sum;
|
||||
@@ -5793,15 +5798,19 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
//qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
//qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]};
|
||||
qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]};
|
||||
acc1 += yl[0] * qf1;
|
||||
acc2 += yl[1] * qf2;
|
||||
|
||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
//qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
||||
//qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
||||
qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]};
|
||||
qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]};
|
||||
acc1 += yl[2] * qf1;
|
||||
acc2 += yl[3] * qf2;
|
||||
|
||||
@@ -7218,13 +7227,43 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
|
||||
template <typename type4x4>
|
||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const half d = xb->d;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
if constexpr (is_same_v<type4x4, half4x4>) {
|
||||
const half d = xb->d;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = (half)qs[i + 16*il] * d;
|
||||
}
|
||||
} else {
|
||||
const float d = xb->d;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = qs[i + 16*il] * d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//template <typename type4x4>
|
||||
//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
// device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
// const half d = xb->d;
|
||||
//
|
||||
// for (int i = 0; i < 16; i++) {
|
||||
// reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
// }
|
||||
//}
|
||||
|
||||
//template <typename type4x4>
|
||||
//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
// device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
// const float d = xb->d;
|
||||
//
|
||||
// float4x4 reg_f;
|
||||
//
|
||||
// for (int i = 0; i < 16; i++) {
|
||||
// reg_f[i/4][i%4] = d * qs[i + 16*il];
|
||||
// }
|
||||
//
|
||||
// reg = (type4x4)reg_f;
|
||||
//}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
@@ -8246,39 +8285,6 @@ kernel void kernel_mul_mm_id(
|
||||
uint ntg = ntg3.x * ntg3.y * ntg3.z;
|
||||
uint n = nei0*nei1;
|
||||
|
||||
//uint npt = (n + ntg - 1) / ntg;
|
||||
//uint first = tiitg * npt;
|
||||
//uint last = first + npt <= n ? first + npt : n;
|
||||
|
||||
//uint nhave = 0;
|
||||
//for (uint i = first; i < last; ++i) {
|
||||
// uint ii0 = i % nei0;
|
||||
// uint ii1 = i / nei0;
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) ++nhave;
|
||||
//}
|
||||
//threadgroup uint * nums = (threadgroup uint *)shared_memory;
|
||||
//nums[tiitg] = nhave;
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
//uint nprev = 0;
|
||||
//for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
|
||||
//int64_t _ne1 = nprev;
|
||||
//for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
|
||||
|
||||
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
//for (uint i = first; i < last; ++i) {
|
||||
// uint ii0 = i % nei0;
|
||||
// uint ii1 = i / nei0;
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
|
||||
//}
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
//
|
||||
// The following is slightly faster than the commented out version above
|
||||
//
|
||||
uint nhave = 0;
|
||||
for (uint i = tiitg; i < n; i += ntg) {
|
||||
uint ii0 = i % nei0;
|
||||
@@ -8290,10 +8296,31 @@ kernel void kernel_mul_mm_id(
|
||||
nums[tiitg] = nhave;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
uint nprev = 0;
|
||||
for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
|
||||
int64_t _ne1 = nprev;
|
||||
for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
|
||||
uint stride = 1;
|
||||
while (stride <= ntg/2) {
|
||||
uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1;
|
||||
if (index < ntg) nums[index] += nums[index-stride];
|
||||
stride <<= 1;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
stride = ntg/2;
|
||||
while (stride > 0) {
|
||||
uint index = (tiitg+1)*stride*2 - 1;
|
||||
if (index+stride < ntg) nums[index+stride] += nums[index];
|
||||
stride >>= 1;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
int64_t _ne1 = nums[ntg-1];
|
||||
if (!_ne1) return;
|
||||
|
||||
uint nprev = tiitg > 0 ? nums[tiitg-1] : 0;
|
||||
|
||||
//uint ncum = 0;
|
||||
//for (uint i = 0; i < tiitg; ++i) ncum += nums[i];
|
||||
//uint nprev = ncum;
|
||||
//for (uint i = tiitg; i < ntg; ++i) ncum += nums[i];
|
||||
//if (!ncum) return;
|
||||
//int64_t _ne1 = ncum;
|
||||
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
for (uint i = tiitg; i < n; i += ntg) {
|
||||
@@ -8304,26 +8331,6 @@ kernel void kernel_mul_mm_id(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// This is the original version that is ridiculously slow.
|
||||
//// row indices
|
||||
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
|
||||
//// TODO: parallelize this loop
|
||||
//int64_t _ne1 = 0;
|
||||
//for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
||||
// for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) {
|
||||
// //if (tiitg == 0) {
|
||||
// rowids[_ne1] = ushort2(ii0, ii1);
|
||||
// //}
|
||||
// _ne1++;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
kernel_mul_mm_id_impl<Dequantizer>(
|
||||
src0,
|
||||
src1,
|
||||
|
||||
Reference in New Issue
Block a user