mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
WIP
This commit is contained in:
@@ -6756,6 +6756,33 @@ bool iqk_soft_max_noalibi(int nc, int ir0, int ir1, int ne00, int ne01,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int nrc_y, typename FloatX, typename FloatY>
|
||||||
|
void mul_mat_fX_fY_fa(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||||
|
assert(n%QFBase::k_step == 0);
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
constexpr int k_nx = 7;
|
||||||
|
#else
|
||||||
|
constexpr int k_nx = 3;
|
||||||
|
#endif
|
||||||
|
const char * cx = (const char *)vx;
|
||||||
|
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||||
|
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||||
|
}
|
||||||
|
int last_x = k_nx*(nrc_x/k_nx);
|
||||||
|
if (last_x == nrc_x) return;
|
||||||
|
int nx = nrc_x - last_x;
|
||||||
|
switch (nx) {
|
||||||
|
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||||
|
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||||
|
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
|
||||||
|
case 5: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 5>>(n, cx, bx, last_x, info); break;
|
||||||
|
case 6: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 6>>(n, cx, bx, last_x, info); break;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void iqk_flash_helper_3(int ne00,
|
void iqk_flash_helper_3(int ne00,
|
||||||
int nq1, // number of elements in q
|
int nq1, // number of elements in q
|
||||||
int nk1, // number of rows in k
|
int nk1, // number of rows in k
|
||||||
@@ -6792,28 +6819,46 @@ void iqk_flash_helper_3(int ne00,
|
|||||||
float qkv_cache[128*q_step];
|
float qkv_cache[128*q_step];
|
||||||
int need_scaling[q_step];
|
int need_scaling[q_step];
|
||||||
auto vscale = _mm512_set1_ps(scale);
|
auto vscale = _mm512_set1_ps(scale);
|
||||||
|
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
S[j] = 0; M[j] = -INFINITY;
|
S[j] = 0; M[j] = -INFINITY;
|
||||||
}
|
}
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||||
//for (int l1 = 0; l1 < k_step; ++l1) {
|
// This is slower
|
||||||
// auto kr = (const ggml_half *)((const char *)k + (k_step*k1 + l1)*stride_k);
|
//DataInfo info{cache, (const char *)(q + q_step*i1*stride_q), k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||||
// for (int i = 0; i < 8; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
|
||||||
// for (int m1 = 0; m1 < q_step; ++m1) {
|
//info.cur_y += q_step/2;
|
||||||
// // q index is q_step*i1 + m1
|
//mul_mat_fX_fY_T<q_step/2, ggml_half, float>(ne00, (const void *)((const char *)k + k_step*k1*stride_k), stride_k, info, k_step);
|
||||||
// // k index is k_step*k1 + l1
|
//for (int j = 0; j < q_step; ++j) {
|
||||||
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*(q_step*i1 + m1)) + k_step*k1;
|
// const ggml_half * mp = (const ggml_half *)((const char *)mask + stride_m*q_step*i1) + k_step*k1;
|
||||||
// if (mp[l1] == h_inf) {
|
// for (int l = 0; l < k_step/16; ++l) {
|
||||||
// cache[k_step*m1 + l1] = -INFINITY;
|
// auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||||
// continue;
|
// auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
|
||||||
// }
|
// vals[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||||
// auto qr = q + (q_step*i1 + m1)*stride_q;
|
|
||||||
// auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
|
|
||||||
// for (int i = 1; i < 8; ++i) vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
|
|
||||||
// cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
|
|
||||||
// }
|
// }
|
||||||
|
// auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||||
|
// need_scaling[j] = 0;
|
||||||
|
// if (smax > M[j]) {
|
||||||
|
// if (M[j] > -INFINITY) {
|
||||||
|
// float m = expf(M[j] - smax);
|
||||||
|
// vms[j] = _mm512_set1_ps(m);
|
||||||
|
// need_scaling[j] = 1;
|
||||||
|
// S[j] *= m;
|
||||||
|
// } else {
|
||||||
|
// need_scaling[j] = 2;
|
||||||
|
// S[j] = 0;
|
||||||
|
// }
|
||||||
|
// M[j] = smax;
|
||||||
|
// }
|
||||||
|
// auto vm = _mm512_set1_ps(M[j]);
|
||||||
|
// for (int l = 0; l < k_step/16; ++l) {
|
||||||
|
// vals[l] = v_expf(_mm512_sub_ps(vals[l], vm));
|
||||||
|
// _mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
||||||
|
// }
|
||||||
|
// S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
||||||
//}
|
//}
|
||||||
|
|
||||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||||
auto kr1 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 0)*stride_k);
|
auto kr1 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 0)*stride_k);
|
||||||
auto kr2 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 1)*stride_k);
|
auto kr2 = (const ggml_half *)((const char *)k + (k_step*k1 + l1 + 1)*stride_k);
|
||||||
@@ -6843,16 +6888,8 @@ void iqk_flash_helper_3(int ne00,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
//auto R = qkv_cache + 128*j;
|
|
||||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
|
||||||
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
for (int l = 0; l < k_step/16; ++l) vals[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||||
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
auto smax = _mm512_reduce_max_ps(_mm512_max_ps(vals[0], vals[1]));
|
||||||
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(_mm512_max_ps(vals[0], vals[1]), _mm512_max_ps(vals[2], vals[3])));
|
|
||||||
//auto val1 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
|
||||||
//auto val2 = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16));
|
|
||||||
//auto smax = _mm512_reduce_max_ps(_mm512_max_ps(val1, val2));
|
|
||||||
////auto val = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j));
|
|
||||||
////auto smax = _mm512_reduce_max_ps(val);
|
|
||||||
need_scaling[j] = 0;
|
need_scaling[j] = 0;
|
||||||
if (smax > M[j]) {
|
if (smax > M[j]) {
|
||||||
if (M[j] > -INFINITY) {
|
if (M[j] > -INFINITY) {
|
||||||
@@ -6872,29 +6909,19 @@ void iqk_flash_helper_3(int ne00,
|
|||||||
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
_mm512_storeu_ps(cache + k_step*j + 16*l, vals[l]);
|
||||||
}
|
}
|
||||||
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
S[j] += _mm512_reduce_add_ps(_mm512_add_ps(vals[0], vals[1]));
|
||||||
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(_mm512_add_ps(vals[0], vals[1]), _mm512_add_ps(vals[2], vals[3])));
|
|
||||||
//val1 = v_expf(_mm512_sub_ps(val1, _mm512_set1_ps(M[j])));
|
|
||||||
//val2 = v_expf(_mm512_sub_ps(val2, _mm512_set1_ps(M[j])));
|
|
||||||
//S[j] += _mm512_reduce_add_ps(_mm512_add_ps(val1, val2));
|
|
||||||
//_mm512_storeu_ps(cache + k_step*j, val1);
|
|
||||||
//_mm512_storeu_ps(cache + k_step*j + 16, val2);
|
|
||||||
////val = v_expf(_mm512_sub_ps(val, _mm512_set1_ps(M[j])));
|
|
||||||
////S[j] += _mm512_reduce_add_ps(val);
|
|
||||||
////_mm512_storeu_ps(cache + k_step*j, val);
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < 8; i += 2) {
|
for (int i = 0; i < 8; i += 2) {
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
if (need_scaling[j] == 2) {
|
if (need_scaling[j] == 2) {
|
||||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||||
} else {
|
} else {
|
||||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
auto R = qkv_cache + 128*j;
|
||||||
auto R = qkv_cache + 128*j;
|
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
if (need_scaling[j] == 1) {
|
||||||
if (need_scaling[j] == 1) {
|
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||||
@@ -6909,7 +6936,6 @@ void iqk_flash_helper_3(int ne00,
|
|||||||
}
|
}
|
||||||
for (int j = 0; j < q_step; ++j) {
|
for (int j = 0; j < q_step; ++j) {
|
||||||
auto R = qkv_cache + 128*j;
|
auto R = qkv_cache + 128*j;
|
||||||
//auto R = qkv + (q_step*i1 + j)*stride_qkv;
|
|
||||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||||
}
|
}
|
||||||
@@ -6929,7 +6955,7 @@ void iqk_flash_helper_3(int ne00,
|
|||||||
return;
|
return;
|
||||||
if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
|
if (nq1%16 != 0 || nk1%16 != 0) printf("Oops(%s): nq1 = %d, nk1 = %d\n", __func__, nq1, nk1);
|
||||||
//GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0);
|
//GGML_ASSERT(nq1%16 == 0 && nk1%16 == 0);
|
||||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
//auto vinf = _mm512_set1_ps(-INFINITY);
|
||||||
for (int i1 = 0; i1 < nq1/16; ++i1) {
|
for (int i1 = 0; i1 < nq1/16; ++i1) {
|
||||||
//int iq1 = 16*i1;
|
//int iq1 = 16*i1;
|
||||||
for (int j1 = 0; j1 < 16; ++j1) {
|
for (int j1 = 0; j1 < 16; ++j1) {
|
||||||
|
|||||||
Reference in New Issue
Block a user