diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7faba6bb..891b3a8f 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -22250,6 +22250,14 @@ static void ggml_compute_forward_ssm_conv_f32(
     // for use with the destination state offset between sequences
     GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
 
+    if (n_kv == 1 && nc == 4) {
+        if (iqk_ssm_conv4(nr, nc, n_t, src0->nb[1], src1->nb[0], src1->nb[1], src2->nb[1],
+                    (const float *)src1->data, (const float *)src0->data, (const float *)src2->data,
+                    (float *)dst->data, ith, nth)) {
+            return;
+        }
+    }
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp
index 4581a1dd..1cc128b5 100644
--- a/ggml/src/iqk/iqk_cpu_ops.cpp
+++ b/ggml/src/iqk/iqk_cpu_ops.cpp
@@ -550,3 +550,77 @@ float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
     //t.join();
     //return result[0] + result[1];
 }
+
+bool iqk_ssm_conv4(int nr, int nc, int nt,
+        uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
+        const float * x0_in, const float * s0_in, const float * c_in,
+        float * dst, int ith, int nth) {
+#ifdef __AVX2__
+    if (nt <= 32 || nc != 4 || nr%16 != 0) {
+        return false;
+    }
+    int nr16 = nr/16;
+    int dr16 = (nr16 + nth - 1)/nth;
+    int ir0 = ith*dr16;
+    int ir1 = std::min(nr16, ir0 + dr16);
+    __m256 vs[8], vc[8];
+    float aux[64];
+    for (int ir = ir0; ir < ir1; ++ir) {
+        auto x = dst + 16*ir;
+        auto s = dst + 16*ir*nb21/sizeof(float) + nr*nt;
+        auto s0 = s0_in + 16*ir*nb01/sizeof(float); // {d_conv - 1, d_inner, n_kv}
+        auto x0 = x0_in + 16*ir*nb10/sizeof(float);
+        auto c = c_in + 16*ir*nb21/sizeof(float);
+        for (int ic = 0; ic < 3; ++ic) {
+            for (int j = 0; j < 8; ++j) {
+                aux[j + 8*ic +  8] = s0[(j+0)*nb01/sizeof(float) + ic];
+                aux[j + 8*ic + 40] = s0[(j+8)*nb01/sizeof(float) + ic];
+            }
+        }
+        // Not necessary, but doing it to shut up compiler warnings
+        for (int j = 0; j < 8; ++j) {
+            aux[j] = aux[j+32] = 0.0f;
+        }
+        for (int k = 0; k < 8; ++k) vs[k] = _mm256_loadu_ps(aux + 8*k);
+        for (int ic = 0; ic < 4; ++ic) {
+            for (int j = 0; j < 8; ++j) {
+                aux[j + 8*ic     ] = c[(j+0)*nb21/sizeof(float) + ic];
+                aux[j + 8*ic + 32] = c[(j+8)*nb21/sizeof(float) + ic];
+            }
+        }
+        for (int k = 0; k < 8; ++k) vc[k] = _mm256_loadu_ps(aux + 8*k);
+        int idx = 0;
+        for (int it = 0; it < nt; ++it) {
+            vs[idx+0] = _mm256_loadu_ps(x0+0);
+            vs[idx+4] = _mm256_loadu_ps(x0+8);
+            idx = (idx + 1) & 3;
+            __m256 sum1 = _mm256_setzero_ps();
+            __m256 sum2 = _mm256_setzero_ps();
+            for (int k = 0; k < 4; ++k) {
+                int ii = (idx + k) & 3;
+                sum1 = _mm256_fmadd_ps(vs[ii+0], vc[k+0], sum1);
+                sum2 = _mm256_fmadd_ps(vs[ii+4], vc[k+4], sum2);
+            }
+            _mm256_storeu_ps(x+0, sum1);
+            _mm256_storeu_ps(x+8, sum2);
+            x0 += nb11/sizeof(float);
+            x += nr;
+        }
+        for (int k = 0; k < 4; ++k) {
+            int ii = (idx + k) & 3;
+            _mm256_storeu_ps(aux + 8*k +  0, vs[ii+0]);
+            _mm256_storeu_ps(aux + 8*k + 32, vs[ii+4]);
+        }
+        for (int j = 0; j < 8; ++j) {
+            for (int ic = 0; ic < 4; ++ic) {
+                s[(j+0)*nb21/sizeof(float) + ic] = aux[j + 8*ic +  0];
+                s[(j+8)*nb21/sizeof(float) + ic] = aux[j + 8*ic + 32];
+            }
+        }
+    }
+    return true;
+#else
+    return false;
+#endif
+}
+
diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h
index 267c3e85..6f4317f9 100644
--- a/ggml/src/iqk/iqk_cpu_ops.h
+++ b/ggml/src/iqk/iqk_cpu_ops.h
@@ -32,6 +32,11 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
 
 float iqk_exp_with_thresh(int n, float * logits, float max, float min);
 
+bool iqk_ssm_conv4(int nr, int nc, int nt,
+        uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
+        const float * x0, const float * s0, const float * c,
+        float * dst, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif