Fix standard attention on the CPU

This commit is contained in:
Iwan Kawrakow
2025-05-15 08:40:47 +03:00
parent 14ed9fb44d
commit ab6077718f

View File

@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
int ny = mm.funcs.size();
while (ny > 0 && !mm.funcs[ny-1]) --ny;
if (ny >= r2) {
int nx64 = Nx/64;
int nchunk64 = nx64*ne02;
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
int i02 = ichunk/nx64;
int ix = 64*(ichunk - i02*nx64);
nchunk = nx32*ne02;
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
int ix = 32*(ichunk - i02*nx32);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
}
int ix0 = 64*nx64;
if (ix0 < Nx) {
nx32 -= 2*nx64;
nchunk = nx32*ne02;
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
int ix = ix0 + 32*(ichunk - i02*nx32);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
}
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
}
return true;
}
return true;
}
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
return true;
}
//if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
int gcd = simple_gcd(ne02, nth);
int counter = 0;
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
//printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;