MoE improvements on Metal

This version beats mainline, but there are things I don't understand:
* Mainline has effectively gone to GEMV for MUL_MAT_ID. We can do the
  same, but we are 30% slower. Why?
* Using actual GEMM, we beat mainline with a ubatch size of 128. But then
  performance degrades. Why?
This commit is contained in:
Iwan Kawrakow
2025-04-02 15:26:19 +02:00
parent 21a5b8bd28
commit 2a5552830b
2 changed files with 4756 additions and 2147 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl(
yl[i] = yb[i];
}
device const block_q8_0 * xr = x + ib;
for (int row = 0; row < nr; row++) {
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
//device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
device const int8_t * qs = xr->qs + NB_Q8_0*il;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*x[ib+row*nb].d;
//sumf[row] += sumq*x[ib+row*nb].d;
sumf[row] += sumq*xr->d;
xr += nb;
}
yb += NB_Q8_0 * nw;
@@ -5746,7 +5751,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
uint tiisg,
uint sgitg) {
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
//threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@@ -5766,8 +5771,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const int ib = it/2;
const int il = it%2;
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
//shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
//threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[2]={0.f}, all_sum;
@@ -5793,15 +5798,19 @@ void kernel_mul_mv_iq4_xs_f32_impl(
aux32[0] = q4[0] & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
//qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
//qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]};
qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[1] & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
//qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
//qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
qf1 = {kvalues_iq4nl_f[q8[0]], kvalues_iq4nl_f[q8[1]], kvalues_iq4nl_f[q8[2]], kvalues_iq4nl_f[q8[3]]};
qf2 = {kvalues_iq4nl_f[q8[4]], kvalues_iq4nl_f[q8[5]], kvalues_iq4nl_f[q8[6]], kvalues_iq4nl_f[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
@@ -7218,13 +7227,43 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (qs[i + 16*il] * d);
if constexpr (is_same_v<type4x4, half4x4>) {
const half d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (half)qs[i + 16*il] * d;
}
} else {
const float d = xb->d;
for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = qs[i + 16*il] * d;
}
}
}
//template <typename type4x4>
//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
// device const int8_t * qs = ((device const int8_t *)xb->qs);
// const half d = xb->d;
//
// for (int i = 0; i < 16; i++) {
// reg[i/4][i%4] = (qs[i + 16*il] * d);
// }
//}
//template <typename type4x4>
//void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
// device const int8_t * qs = ((device const int8_t *)xb->qs);
// const float d = xb->d;
//
// float4x4 reg_f;
//
// for (int i = 0; i < 16; i++) {
// reg_f[i/4][i%4] = d * qs[i + 16*il];
// }
//
// reg = (type4x4)reg_f;
//}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) {
const float d = xb->d;
@@ -8246,39 +8285,6 @@ kernel void kernel_mul_mm_id(
uint ntg = ntg3.x * ntg3.y * ntg3.z;
uint n = nei0*nei1;
//uint npt = (n + ntg - 1) / ntg;
//uint first = tiitg * npt;
//uint last = first + npt <= n ? first + npt : n;
//uint nhave = 0;
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) ++nhave;
//}
//threadgroup uint * nums = (threadgroup uint *)shared_memory;
//nums[tiitg] = nhave;
//threadgroup_barrier(mem_flags::mem_threadgroup);
//uint nprev = 0;
//for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
//int64_t _ne1 = nprev;
//for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
//
// The following is slightly faster than the commented out version above
//
uint nhave = 0;
for (uint i = tiitg; i < n; i += ntg) {
uint ii0 = i % nei0;
@@ -8290,10 +8296,31 @@ kernel void kernel_mul_mm_id(
nums[tiitg] = nhave;
threadgroup_barrier(mem_flags::mem_threadgroup);
uint nprev = 0;
for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
int64_t _ne1 = nprev;
for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
uint stride = 1;
while (stride <= ntg/2) {
uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1;
if (index < ntg) nums[index] += nums[index-stride];
stride <<= 1;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
stride = ntg/2;
while (stride > 0) {
uint index = (tiitg+1)*stride*2 - 1;
if (index+stride < ntg) nums[index+stride] += nums[index];
stride >>= 1;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
int64_t _ne1 = nums[ntg-1];
if (!_ne1) return;
uint nprev = tiitg > 0 ? nums[tiitg-1] : 0;
//uint ncum = 0;
//for (uint i = 0; i < tiitg; ++i) ncum += nums[i];
//uint nprev = ncum;
//for (uint i = tiitg; i < ntg; ++i) ncum += nums[i];
//if (!ncum) return;
//int64_t _ne1 = ncum;
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
for (uint i = tiitg; i < n; i += ntg) {
@@ -8304,26 +8331,6 @@ kernel void kernel_mul_mm_id(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// This is the original version that is ridiculously slow.
//// row indices
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
//// TODO: parallelize this loop
//int64_t _ne1 = 0;
//for (ushort ii1 = 0; ii1 < nei1; ii1++) {
// for (ushort ii0 = 0; ii0 < nei0; ii0++) {
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) {
// //if (tiitg == 0) {
// rowids[_ne1] = ushort2(ii0, ii1);
// //}
// _ne1++;
// }
// }
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,