Metal: much faster MoE prompt processing (#307)

* MoE improvements on Metal

This version beats mainline, but there are things I don't understand:
* Mainline has effectively gone to GEMV for MUL_MAT_ID. We can do the
  same, but we are 30% slower. Why?
* Using actual GEMM, we beat mainline with a ubatch size of 128. But then
  performance degrades. Why?

* Some cleanup

* Much better

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-04-03 07:15:49 +02:00
committed by GitHub
parent 6d405d1fd1
commit 07dbc1aa06
2 changed files with 2234 additions and 2158 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl(
yl[i] = yb[i];
}
device const block_q8_0 * xr = x + ib;
for (int row = 0; row < nr; row++) {
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
//device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
device const int8_t * qs = xr->qs + NB_Q8_0*il;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*x[ib+row*nb].d;
//sumf[row] += sumq*x[ib+row*nb].d;
sumf[row] += sumq*xr->d;
xr += nb;
}
yb += NB_Q8_0 * nw;
@@ -7218,10 +7223,16 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (qs[i + 16*il] * d);
if constexpr (is_same_v<type4x4, half4x4>) {
const half d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (half)qs[i + 16*il] * d;
}
} else {
const float d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = qs[i + 16*il] * d;
}
}
}
@@ -8246,39 +8257,6 @@ kernel void kernel_mul_mm_id(
uint ntg = ntg3.x * ntg3.y * ntg3.z;
uint n = nei0*nei1;
//uint npt = (n + ntg - 1) / ntg;
//uint first = tiitg * npt;
//uint last = first + npt <= n ? first + npt : n;
//uint nhave = 0;
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) ++nhave;
//}
//threadgroup uint * nums = (threadgroup uint *)shared_memory;
//nums[tiitg] = nhave;
//threadgroup_barrier(mem_flags::mem_threadgroup);
//uint nprev = 0;
//for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
//int64_t _ne1 = nprev;
//for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
//
// The following is slightly faster than the commented out version above
//
uint nhave = 0;
for (uint i = tiitg; i < n; i += ntg) {
uint ii0 = i % nei0;
@@ -8290,10 +8268,24 @@ kernel void kernel_mul_mm_id(
nums[tiitg] = nhave;
threadgroup_barrier(mem_flags::mem_threadgroup);
uint nprev = 0;
for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
int64_t _ne1 = nprev;
for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
uint stride = 1;
while (stride <= ntg/2) {
uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1;
if (index < ntg) nums[index] += nums[index-stride];
stride <<= 1;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
stride = ntg/2;
while (stride > 0) {
uint index = (tiitg+1)*stride*2 - 1;
if (index+stride < ntg) nums[index+stride] += nums[index];
stride >>= 1;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
uint _ne1 = nums[ntg-1];
if (!_ne1) return;
uint nprev = tiitg > 0 ? nums[tiitg-1] : 0;
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
for (uint i = tiitg; i < n; i += ntg) {
@@ -8304,47 +8296,37 @@ kernel void kernel_mul_mm_id(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// This is the original version that is ridiculously slow.
//// row indices
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
uint nstep = (_ne1 + BLOCK_SIZE_N - 1)/BLOCK_SIZE_N;
//// TODO: parallelize this loop
//int64_t _ne1 = 0;
//for (ushort ii1 = 0; ii1 < nei1; ii1++) {
// for (ushort ii0 = 0; ii0 < nei0; ii0++) {
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) {
// //if (tiitg == 0) {
// rowids[_ne1] = ushort2(ii0, ii1);
// //}
// _ne1++;
// }
// }
//}
for (uint istep = 0; istep < nstep; ++istep) {
//threadgroup_barrier(mem_flags::mem_threadgroup);
uint first = BLOCK_SIZE_N*istep;
uint last = first + BLOCK_SIZE_N < _ne1 ? first + BLOCK_SIZE_N : _ne1;
int64_t this_ne1 = last - first;
threadgroup ushort2 * this_rowids = rowids + istep*BLOCK_SIZE_N;
kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,
rowids,
dst,
ne00,
ne02,
nb01,
nb02,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
_ne1,
ne0*ne1,
shared_memory,
tgpig,
tiitg,
sgitg);
kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,
this_rowids,
dst,
ne00,
ne02,
nb01,
nb02,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
this_ne1,
ne0*ne1,
shared_memory,
tgpig,
tiitg,
sgitg);
}
}
#define QK_NL 16