mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 03:20:00 +00:00
Metal: much faster MoE prompt processing (#307)
* MoE improvements on Metal This version beats mainline, but there are things I don't understand: * Mainline has effectively gone to GEMV for MUL_MAT_ID. We can do the same, but we are 30% slower. Why? * Using actual GEMM, we beat mainline with a u-batch size of 128. But then performance degrades. Why? * Some cleanup * Much better --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||
yl[i] = yb[i];
|
||||
}
|
||||
|
||||
device const block_q8_0 * xr = x + ib;
|
||||
|
||||
for (int row = 0; row < nr; row++) {
|
||||
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||
//device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||
device const int8_t * qs = xr->qs + NB_Q8_0*il;
|
||||
float sumq = 0.f;
|
||||
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||
sumq += qs[iq] * yl[iq];
|
||||
}
|
||||
sumf[row] += sumq*x[ib+row*nb].d;
|
||||
//sumf[row] += sumq*x[ib+row*nb].d;
|
||||
sumf[row] += sumq*xr->d;
|
||||
xr += nb;
|
||||
}
|
||||
|
||||
yb += NB_Q8_0 * nw;
|
||||
@@ -7218,10 +7223,16 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
|
||||
template <typename type4x4>
|
||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const half d = xb->d;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
if constexpr (is_same_v<type4x4, half4x4>) {
|
||||
const half d = xb->d;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = (half)qs[i + 16*il] * d;
|
||||
}
|
||||
} else {
|
||||
const float d = xb->d;
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = qs[i + 16*il] * d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8246,39 +8257,6 @@ kernel void kernel_mul_mm_id(
|
||||
uint ntg = ntg3.x * ntg3.y * ntg3.z;
|
||||
uint n = nei0*nei1;
|
||||
|
||||
//uint npt = (n + ntg - 1) / ntg;
|
||||
//uint first = tiitg * npt;
|
||||
//uint last = first + npt <= n ? first + npt : n;
|
||||
|
||||
//uint nhave = 0;
|
||||
//for (uint i = first; i < last; ++i) {
|
||||
// uint ii0 = i % nei0;
|
||||
// uint ii1 = i / nei0;
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) ++nhave;
|
||||
//}
|
||||
//threadgroup uint * nums = (threadgroup uint *)shared_memory;
|
||||
//nums[tiitg] = nhave;
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
//uint nprev = 0;
|
||||
//for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
|
||||
//int64_t _ne1 = nprev;
|
||||
//for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
|
||||
|
||||
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
//for (uint i = first; i < last; ++i) {
|
||||
// uint ii0 = i % nei0;
|
||||
// uint ii1 = i / nei0;
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
|
||||
//}
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
//
|
||||
// The following is slightly faster than the commented out version above
|
||||
//
|
||||
uint nhave = 0;
|
||||
for (uint i = tiitg; i < n; i += ntg) {
|
||||
uint ii0 = i % nei0;
|
||||
@@ -8290,10 +8268,24 @@ kernel void kernel_mul_mm_id(
|
||||
nums[tiitg] = nhave;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
uint nprev = 0;
|
||||
for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
|
||||
int64_t _ne1 = nprev;
|
||||
for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
|
||||
uint stride = 1;
|
||||
while (stride <= ntg/2) {
|
||||
uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1;
|
||||
if (index < ntg) nums[index] += nums[index-stride];
|
||||
stride <<= 1;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
stride = ntg/2;
|
||||
while (stride > 0) {
|
||||
uint index = (tiitg+1)*stride*2 - 1;
|
||||
if (index+stride < ntg) nums[index+stride] += nums[index];
|
||||
stride >>= 1;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
uint _ne1 = nums[ntg-1];
|
||||
if (!_ne1) return;
|
||||
|
||||
uint nprev = tiitg > 0 ? nums[tiitg-1] : 0;
|
||||
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
for (uint i = tiitg; i < n; i += ntg) {
|
||||
@@ -8304,47 +8296,37 @@ kernel void kernel_mul_mm_id(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// This is the original version that is ridiculously slow.
|
||||
//// row indices
|
||||
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
uint nstep = (_ne1 + BLOCK_SIZE_N - 1)/BLOCK_SIZE_N;
|
||||
|
||||
//// TODO: parallelize this loop
|
||||
//int64_t _ne1 = 0;
|
||||
//for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
||||
// for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
||||
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
// if (id == i02) {
|
||||
// //if (tiitg == 0) {
|
||||
// rowids[_ne1] = ushort2(ii0, ii1);
|
||||
// //}
|
||||
// _ne1++;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
for (uint istep = 0; istep < nstep; ++istep) {
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
uint first = BLOCK_SIZE_N*istep;
|
||||
uint last = first + BLOCK_SIZE_N < _ne1 ? first + BLOCK_SIZE_N : _ne1;
|
||||
int64_t this_ne1 = last - first;
|
||||
threadgroup ushort2 * this_rowids = rowids + istep*BLOCK_SIZE_N;
|
||||
|
||||
kernel_mul_mm_id_impl<Dequantizer>(
|
||||
src0,
|
||||
src1,
|
||||
rowids,
|
||||
dst,
|
||||
ne00,
|
||||
ne02,
|
||||
nb01,
|
||||
nb02,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
_ne1,
|
||||
ne0*ne1,
|
||||
shared_memory,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
kernel_mul_mm_id_impl<Dequantizer>(
|
||||
src0,
|
||||
src1,
|
||||
this_rowids,
|
||||
dst,
|
||||
ne00,
|
||||
ne02,
|
||||
nb01,
|
||||
nb02,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
this_ne1,
|
||||
ne0*ne1,
|
||||
shared_memory,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
}
|
||||
}
|
||||
|
||||
#define QK_NL 16
|
||||
|
||||
Reference in New Issue
Block a user