mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-14 07:48:16 +00:00
Faster convolution on AVX2 (#1400)
* Faster ssm_conv on AVX2 * Move the optimized ssm_conv to iqk * Minor
This commit is contained in:
@@ -22250,6 +22250,14 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||
// for use with the destination state offset between sequences
|
||||
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
||||
|
||||
if (n_kv == 1 && nc == 4) {
|
||||
if (iqk_ssm_conv4(nr, nc, n_t, src0->nb[1], src1->nb[0], src1->nb[1], src2->nb[1],
|
||||
(const float *)src1->data, (const float *)src0->data, (const float *)src2->data,
|
||||
(float *)dst->data, ith, nth)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
|
||||
@@ -550,3 +550,77 @@ float iqk_exp_with_thresh(int n, float * logits, float max, float min) {
|
||||
//t.join();
|
||||
//return result[0] + result[1];
|
||||
}
|
||||
|
||||
bool iqk_ssm_conv4(int nr, int nc, int nt,
|
||||
uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
|
||||
const float * x0_in, const float * s0_in, const float * c_in,
|
||||
float * dst, int ith, int nth) {
|
||||
#ifdef __AVX2__
|
||||
if (nt <= 32 || nc != 4 || nr%16 != 0) {
|
||||
return false;
|
||||
}
|
||||
int nr16 = nr/16;
|
||||
int dr16 = (nr16 + nth - 1)/nth;
|
||||
int ir0 = ith*dr16;
|
||||
int ir1 = std::min(nr16, ir0 + dr16);
|
||||
__m256 vs[8], vc[8];
|
||||
float aux[64];
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
auto x = dst + 16*ir;
|
||||
auto s = dst + 16*ir*nb21/sizeof(float) + nr*nt;
|
||||
auto s0 = s0_in + 16*ir*nb01/sizeof(float); // {d_conv - 1, d_inner, n_kv}
|
||||
auto x0 = x0_in + 16*ir*nb10/sizeof(float);
|
||||
auto c = c_in + 16*ir*nb21/sizeof(float);
|
||||
for (int ic = 0; ic < 3; ++ic) {
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
aux[j + 8*ic + 8] = s0[(j+0)*nb01/sizeof(float) + ic];
|
||||
aux[j + 8*ic + 40] = s0[(j+8)*nb01/sizeof(float) + ic];
|
||||
}
|
||||
}
|
||||
// Not necessary, but doing it to shut up compiler warnings
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
aux[j] = aux[j+32] = 0.0f;
|
||||
}
|
||||
for (int k = 0; k < 8; ++k) vs[k] = _mm256_loadu_ps(aux + 8*k);
|
||||
for (int ic = 0; ic < 4; ++ic) {
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
aux[j + 8*ic ] = c[(j+0)*nb21/sizeof(float) + ic];
|
||||
aux[j + 8*ic + 32] = c[(j+8)*nb21/sizeof(float) + ic];
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < 8; ++k) vc[k] = _mm256_loadu_ps(aux + 8*k);
|
||||
int idx = 0;
|
||||
for (int it = 0; it < nt; ++it) {
|
||||
vs[idx+0] = _mm256_loadu_ps(x0+0);
|
||||
vs[idx+4] = _mm256_loadu_ps(x0+8);
|
||||
idx = (idx + 1) & 3;
|
||||
__m256 sum1 = _mm256_setzero_ps();
|
||||
__m256 sum2 = _mm256_setzero_ps();
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
int ii = (idx + k) & 3;
|
||||
sum1 = _mm256_fmadd_ps(vs[ii+0], vc[k+0], sum1);
|
||||
sum2 = _mm256_fmadd_ps(vs[ii+4], vc[k+4], sum2);
|
||||
}
|
||||
_mm256_storeu_ps(x+0, sum1);
|
||||
_mm256_storeu_ps(x+8, sum2);
|
||||
x0 += nb11/sizeof(float);
|
||||
x += nr;
|
||||
}
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
int ii = (idx + k) & 3;
|
||||
_mm256_storeu_ps(aux + 8*k + 0, vs[ii+0]);
|
||||
_mm256_storeu_ps(aux + 8*k + 32, vs[ii+4]);
|
||||
}
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int ic = 0; ic < 4; ++ic) {
|
||||
s[(j+0)*nb21/sizeof(float) + ic] = aux[j + 8*ic + 0];
|
||||
s[(j+8)*nb21/sizeof(float) + ic] = aux[j + 8*ic + 32];
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,11 @@ void iqk_hadamard(struct ggml_tensor * dst, int ith, int nth);
|
||||
|
||||
float iqk_exp_with_thresh(int n, float * logits, float max, float min);
|
||||
|
||||
bool iqk_ssm_conv4(int nr, int nc, int nt,
|
||||
uint64_t nb01, uint64_t nb10, uint64_t nb11, uint64_t nb21,
|
||||
const float * x0, const float * s0, const float * c,
|
||||
float * dst, int ith, int nth);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user