Faster convolution on AVX2 (#1400)

* Faster ssm_conv on AVX2

* Move the optimized ssm_conv to iqk

* Minor
This commit is contained in:
Kawrakow
2026-03-11 19:28:38 +01:00
committed by GitHub
parent 1f4dcab5c6
commit afa6439ac3
3 changed files with 87 additions and 0 deletions

View File

@@ -22250,6 +22250,14 @@ static void ggml_compute_forward_ssm_conv_f32(
// for use with the destination state offset between sequences
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
if (n_kv == 1 && nc == 4) {
if (iqk_ssm_conv4(nr, nc, n_t, src0->nb[1], src1->nb[0], src1->nb[1], src2->nb[1],
(const float *)src1->data, (const float *)src0->data, (const float *)src2->data,
(float *)dst->data, ith, nth)) {
return;
}
}
// rows per thread
const int dr = (nr + nth - 1)/nth;

View File

@@ -550,3 +550,77 @@ float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
//t.join();
//return result[0] + result[1];
}
// AVX2 fast path for the SSM 1D convolution, specialized for a convolution
// width of 4 (nc == 4) and a single sequence (the caller only invokes this
// when n_kv == 1; see ggml_compute_forward_ssm_conv_f32).
//
//   nr      - number of rows (inner dimension); must be a multiple of 16 here
//   nc      - convolution width; only nc == 4 is handled
//   nt      - number of time steps (tokens)
//   nb01    - byte stride between rows of the previous-state tensor s0_in
//             (caller passes src0->nb[1]; per the caller's comment the state
//             layout is {d_conv - 1, d_inner, n_kv}, i.e. 3 floats per row)
//   nb10    - byte stride between rows of the input tensor x0_in (src1->nb[0])
//   nb11    - byte stride between time steps of the input tensor (src1->nb[1])
//   nb21    - byte stride between rows of the weight tensor c_in (src2->nb[1])
//   x0_in   - input activations: one column of nr floats per time step
//   s0_in   - incoming convolution state: the last nc-1 = 3 inputs per row
//   c_in    - convolution weights: nc = 4 floats per row
//   dst     - output: nr*nt results (row-major, nr per step), followed at
//             offset nr*nt by the updated per-row state
//   ith/nth - this thread's index / total number of threads
//
// Returns true if the computation was performed, false if the caller must fall
// back to the generic implementation (unsupported shape, or built without AVX2).
bool iqk_ssm_conv4(int nr, int nc, int nt,
                   uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
                   const float * x0_in, const float * s0_in, const float * c_in,
                   float * dst, int ith, int nth) {
#ifdef __AVX2__
    // Bail out when the shape is unsupported (nc != 4, nr not a multiple of
    // 16) or when there are too few time steps (nt <= 32) — presumably not
    // worth the SIMD setup cost in that case.
    if (nt <= 32 || nc != 4 || nr%16 != 0) {
        return false;
    }
    // Rows are processed in groups of 16 (two 8-float AVX registers), and the
    // groups are divided evenly among the nth threads.
    int nr16 = nr/16;
    int dr16 = (nr16 + nth - 1)/nth;
    int ir0 = ith*dr16;
    int ir1 = std::min(nr16, ir0 + dr16);
    // vs[0..3] is a 4-entry ring buffer holding the sliding 4-tap input window
    // for the low 8 rows of the group; vs[4..7] is the same ring for the high
    // 8 rows. vc[0..3] / vc[4..7] hold the 4 weight taps for the low / high
    // 8 rows, transposed so that one register carries one tap for 8 rows.
    __m256 vs[8], vc[8];
    // Scratch buffer used to transpose row-major tensor data into the
    // tap-major register layout (and back, for the final state).
    float aux[64];
    for (int ir = ir0; ir < ir1; ++ir) {
        auto x = dst + 16*ir;                                    // output cursor for this group
        // NOTE(review): the updated state is appended to dst after the nr*nt
        // outputs, addressed with the weight-tensor row stride nb21 —
        // presumably dst's state region shares src2's row layout; verify
        // against the consumer of this buffer.
        auto s = dst + 16*ir*nb21/sizeof(float) + nr*nt;
        auto s0 = s0_in + 16*ir*nb01/sizeof(float); // {d_conv - 1, d_inner, n_kv}
        auto x0 = x0_in + 16*ir*nb10/sizeof(float);
        auto c = c_in + 16*ir*nb21/sizeof(float);
        // Transpose the incoming state: s0 holds 3 previous inputs per row;
        // they land in aux slots 8..31 (rows 0..7) and 40..63 (rows 8..15),
        // i.e. ring slots vs[1..3] and vs[5..7]. Ring slots vs[0]/vs[4] are
        // reserved for the first new input loaded in the time loop below.
        for (int ic = 0; ic < 3; ++ic) {
            for (int j = 0; j < 8; ++j) {
                aux[j + 8*ic +  8] = s0[(j+0)*nb01/sizeof(float) + ic];
                aux[j + 8*ic + 40] = s0[(j+8)*nb01/sizeof(float) + ic];
            }
        }
        // Not necessary, but doing it to shut up compiler warnings
        // (these slots are overwritten on the first loop iteration anyway).
        for (int j = 0; j < 8; ++j) {
            aux[j] = aux[j+32] = 0.0f;
        }
        for (int k = 0; k < 8; ++k) vs[k] = _mm256_loadu_ps(aux + 8*k);
        // Transpose the 4 convolution weights per row into vc: vc[k] holds
        // tap k for rows 0..7, vc[k+4] holds tap k for rows 8..15.
        for (int ic = 0; ic < 4; ++ic) {
            for (int j = 0; j < 8; ++j) {
                aux[j + 8*ic     ] = c[(j+0)*nb21/sizeof(float) + ic];
                aux[j + 8*ic + 32] = c[(j+8)*nb21/sizeof(float) + ic];
            }
        }
        for (int k = 0; k < 8; ++k) vc[k] = _mm256_loadu_ps(aux + 8*k);
        int idx = 0;
        for (int it = 0; it < nt; ++it) {
            // Push this step's input into the ring slot that held the oldest
            // tap; advancing idx afterwards makes (idx+k)&3 walk the window
            // from oldest (k=0) to newest (k=3), matching weight order vc[k].
            vs[idx+0] = _mm256_loadu_ps(x0+0);
            vs[idx+4] = _mm256_loadu_ps(x0+8);
            idx = (idx + 1) & 3;
            __m256 sum1 = _mm256_setzero_ps();
            __m256 sum2 = _mm256_setzero_ps();
            // 4-tap dot product per row: sum_k window[k] * weight[k],
            // low and high 8 rows accumulated independently via FMA.
            for (int k = 0; k < 4; ++k) {
                int ii = (idx + k) & 3;
                sum1 = _mm256_fmadd_ps(vs[ii+0], vc[k+0], sum1);
                sum2 = _mm256_fmadd_ps(vs[ii+4], vc[k+4], sum2);
            }
            _mm256_storeu_ps(x+0, sum1);
            _mm256_storeu_ps(x+8, sum2);
            x0 += nb11/sizeof(float);  // next input column
            x  += nr;                  // next output column (nr floats per step)
        }
        // Unroll the ring back into chronological order (oldest first) so the
        // final 4-tap window can be written out as the new state.
        for (int k = 0; k < 4; ++k) {
            int ii = (idx + k) & 3;
            _mm256_storeu_ps(aux + 8*k +  0, vs[ii+0]);
            _mm256_storeu_ps(aux + 8*k + 32, vs[ii+4]);
        }
        // Transpose back to row-major and store the state after the outputs.
        // NOTE(review): all 4 taps are written even though the incoming state
        // carries only d_conv-1 = 3 per row — presumably the consumer reads
        // the last 3; confirm against the state layout used by the caller.
        for (int j = 0; j < 8; ++j) {
            for (int ic = 0; ic < 4; ++ic) {
                s[(j+0)*nb21/sizeof(float) + ic] = aux[j + 8*ic +  0];
                s[(j+8)*nb21/sizeof(float) + ic] = aux[j + 8*ic + 32];
            }
        }
    }
    return true;
#else
    return false;
#endif
}

View File

@@ -32,6 +32,11 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
float iqk_exp_with_thresh(int n, float * logits, float max, float min);
bool iqk_ssm_conv4(int nr, int nc, int nt,
uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
const float * x0, const float * s0, const float * c,
float * dst, int ith, int nth);
#ifdef __cplusplus
}
#endif