mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-25 08:59:30 +00:00
WIP: plugging into ggml_compute_forward_flash_attn_ext_f16
Nothing I have tried is faster than the current main branch. I guess doing vector dot products just cannot compete with a tiled matrix multiplication.
This commit is contained in:
@@ -3252,6 +3252,53 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_1xN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
int nb4 = n/4;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Acc acc[Qx::nrc];
|
||||
if (nb <= Qx::nrc) {
|
||||
QFBase::Data yv[Qx::nrc];
|
||||
for (int i = 0; i < nb; ++i) yv[i] = y.load1(0, i);
|
||||
//for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// auto sum = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
// for (int i = 1; i < nb; ++i) {
|
||||
// sum = QFBase::acc(sum, yv[i], x.load1(ix, i));
|
||||
// }
|
||||
// info.store(ix0+ix, 0, QFBase::hsum(sum));
|
||||
//}
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc_first(yv[0], x.load1(ix, 0));
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[ix] = QFBase::acc(acc[ix], yv[i], x.load1(ix, i));
|
||||
}
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
|
||||
return;
|
||||
}
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
yv = y.load_tail(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load_tail(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(acc[ix]));
|
||||
}
|
||||
|
||||
// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
|
||||
// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
|
||||
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
|
||||
@@ -3280,6 +3327,33 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FloatX, typename FloatY>
|
||||
void mul_mat_fX_fY_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
#ifdef __AVX512F__
|
||||
constexpr int k_nx = 8;
|
||||
#else
|
||||
constexpr int k_nx = 4;
|
||||
#endif
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||
#ifdef __AVX512F__
|
||||
case 4: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
|
||||
case 5: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 5>>(n, cx, bx, last_x, info); break;
|
||||
case 6: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 6>>(n, cx, bx, last_x, info); break;
|
||||
case 7: mul_mat_Qx_Qy_1xN<QFT<FloatY, 1>, QFT<FloatX, 7>>(n, cx, bx, last_x, info); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Tiled Q8_0 x Q8_0 implementation. Not used as the templated legacy quant implementation
|
||||
// above is faster. Left behind so we remember we tried.
|
||||
@@ -6029,6 +6103,42 @@ inline __m256 v_expf(__m256 x) {
|
||||
}
|
||||
#endif
|
||||
|
||||
inline float prepare_softmax(int nc, const float * sp, float scale, float slope, const char * mp, bool use_fp16, __m512 * values) {
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
if (mp) {
|
||||
__m512 vslope = _mm512_set1_ps(slope);
|
||||
if (use_fp16) {
|
||||
const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp;
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
const __m512 m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + i));
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
vmax = _mm512_max_ps(vmax , y);
|
||||
values[i] = y;
|
||||
}
|
||||
} else {
|
||||
const float * mp_f32 = (const float *)mp;
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i);
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
vmax = _mm512_max_ps(vmax, y);
|
||||
values[i] = y;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_mul_ps(vscale, x);
|
||||
vmax = _mm512_max_ps(vmax, y);
|
||||
values[i] = y;
|
||||
}
|
||||
}
|
||||
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
|
||||
|
||||
inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) {
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
@@ -6121,6 +6231,23 @@ void softmax_extended(int nc, float * sp, float * dp, float scale, float slope,
|
||||
do_scale(nc, sp, dp, 1.f/sum);
|
||||
}
|
||||
|
||||
float softmax_extended_dont_scale(int nc, float * sp, float scale, float slope, const char * mp, bool mask_is_fp16) {
|
||||
if (nc/16 <= 16 && 16*(nc/16) == nc) {
|
||||
__m512 v[16];
|
||||
auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16, v);
|
||||
auto vmax = _mm512_set1_ps(-max);
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
v[i] = v_expf(_mm512_add_ps(v[i], vmax));
|
||||
vsum = _mm512_add_ps(vsum, v[i]);
|
||||
_mm512_storeu_ps(sp + 16*i, v[i]);
|
||||
}
|
||||
return _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
|
||||
return do_soft_max(nc, sp, max);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
|
||||
@@ -6259,10 +6386,82 @@ void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
GGML_ASSERT(nq % 4 == 0);
|
||||
//GGML_ASSERT(nq / 16 <= 16);
|
||||
|
||||
if (nq == 128) {
|
||||
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
|
||||
__m512 vq[8];
|
||||
__m512 acc[8] = {};
|
||||
for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
|
||||
float M = -INFINITY;
|
||||
float S = 0;
|
||||
int ik = 0;
|
||||
if (nk >= 8) {
|
||||
float s8[8];
|
||||
for (int ik8 = 0; ik8 < nk/8; ++ik8) {
|
||||
float smax = -INFINITY;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j));
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum);
|
||||
if (mp) s += slope*GGML_FP16_TO_FP32(mp[8*ik8+j]);
|
||||
s8[j] = s;
|
||||
smax = std::max(smax, s);
|
||||
}
|
||||
if (smax > M) {
|
||||
float ms = expf(M - smax);
|
||||
auto scale = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale);
|
||||
S *= ms;
|
||||
M = smax;
|
||||
}
|
||||
_mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M))));
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j));
|
||||
auto vs = _mm512_set1_ps(s8[j]);
|
||||
S += s8[j];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
ik = 8*(nk/8);
|
||||
}
|
||||
for (; ik < nk; ++ik) {
|
||||
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
if (mv == -INFINITY) continue;
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik);
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
if (s > M) {
|
||||
float ms = expf(M - s);
|
||||
auto vms = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)));
|
||||
}
|
||||
S = ms*S + 1;
|
||||
M = s;
|
||||
} else {
|
||||
float vs = expf(s - M);
|
||||
auto vvs = _mm512_set1_ps(vs);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
S += vs;
|
||||
}
|
||||
}
|
||||
auto norm = _mm512_set1_ps(1/S);
|
||||
for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
|
||||
return;
|
||||
}
|
||||
|
||||
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
|
||||
|
||||
mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
|
||||
mul_mat_fX_fY_1<ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
//mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
//softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
|
||||
auto sum = softmax_extended_dont_scale(nk, qk, scale, slope, (const char *)mask, true);
|
||||
|
||||
GGML_ASSERT(nq%16 == 0);
|
||||
if (nq/16 <= 16) {
|
||||
@@ -6280,14 +6479,16 @@ void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
v_qkv[j] = _mm512_fmadd_ps(v_qk, v_v, v_qkv[j]);
|
||||
}
|
||||
}
|
||||
auto vnorm = _mm512_set1_ps(1/sum);
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
_mm512_storeu_ps(qkv + 16*j, v_qkv[j]);
|
||||
_mm512_storeu_ps(qkv + 16*j, _mm512_mul_ps(vnorm, v_qkv[j]));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
float norm = 1/sum;
|
||||
for (int ic = 0; ic < nk; ++ic) {
|
||||
auto v_qk = _mm512_set1_ps(qk[ic]);
|
||||
auto v_qk = _mm512_set1_ps(qk[ic]*norm);
|
||||
const ggml_half * vr = (const ggml_half *)((const char *)v + ic*stride_v);
|
||||
for (int j = 0; j < nq/16; ++j) {
|
||||
auto v_v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr + j));
|
||||
|
||||
Reference in New Issue
Block a user