mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP
This commit is contained in:
@@ -6103,7 +6103,7 @@ inline __m256 v_expf(__m256 x) {
|
||||
}
|
||||
#endif
|
||||
|
||||
inline float prepare_softmax(int nc, const float * sp, float scale, float slope, const char * mp, bool use_fp16, __m512 * values) {
|
||||
inline std::pair<float, int> prepare_softmax(int nc, const float * sp, float scale, float slope, const char * mp, bool use_fp16, __m512 * values) {
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
if (mp) {
|
||||
@@ -6111,6 +6111,9 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope,
|
||||
if (use_fp16) {
|
||||
const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp;
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
if (GGML_FP16_TO_FP32(mp_f16[16*i]) == -INFINITY) {
|
||||
nc = 16*i; break;
|
||||
}
|
||||
const __m512 m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + i));
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
@@ -6120,6 +6123,9 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope,
|
||||
} else {
|
||||
const float * mp_f32 = (const float *)mp;
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
if (mp_f32[16*i] == -INFINITY) {
|
||||
nc = 16*i; break;
|
||||
}
|
||||
const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i);
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
@@ -6135,11 +6141,11 @@ inline float prepare_softmax(int nc, const float * sp, float scale, float slope,
|
||||
values[i] = y;
|
||||
}
|
||||
}
|
||||
return _mm512_reduce_max_ps(vmax);
|
||||
return std::make_pair(_mm512_reduce_max_ps(vmax), nc);
|
||||
}
|
||||
|
||||
|
||||
inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) {
|
||||
inline std::pair<float, int> prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) {
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
float scalar_max = -INFINITY;
|
||||
@@ -6147,27 +6153,26 @@ inline float prepare_softmax(int nc, float * sp, float scale, float slope, const
|
||||
__m512 vslope = _mm512_set1_ps(slope);
|
||||
if (use_fp16) {
|
||||
const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp;
|
||||
__m512 vmax1 = _mm512_set1_ps(-INFINITY);
|
||||
for (int i = 0; i < nc/32; ++i) {
|
||||
const __m512 m1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+0));
|
||||
const __m512 m2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+1));
|
||||
const __m512 x1 = _mm512_loadu_ps(sp + 32*i + 0);
|
||||
const __m512 x2 = _mm512_loadu_ps(sp + 32*i + 16);
|
||||
const __m512 y1 = _mm512_fmadd_ps(vslope, m1, _mm512_mul_ps(vscale, x1));
|
||||
const __m512 y2 = _mm512_fmadd_ps(vslope, m2, _mm512_mul_ps(vscale, x2));
|
||||
vmax = _mm512_max_ps(vmax , y1);
|
||||
vmax1 = _mm512_max_ps(vmax1, y2);
|
||||
_mm512_storeu_ps(sp + 32*i + 0, y1);
|
||||
_mm512_storeu_ps(sp + 32*i + 16, y2);
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
if (GGML_FP16_TO_FP32(mp_f16[16*i]) == -INFINITY) {
|
||||
nc = 16*i; break;
|
||||
}
|
||||
const __m512 m = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16+i));
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
vmax = _mm512_max_ps(vmax , y);
|
||||
_mm512_storeu_ps(sp + 16*i, y);
|
||||
}
|
||||
vmax = _mm512_max_ps(vmax, vmax1);
|
||||
for (int i = 32*(nc/32); i < nc; ++i) {
|
||||
for (int i = 16*(nc/16); i < nc; ++i) {
|
||||
sp[i] = scale*sp[i] + slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
scalar_max = std::max(scalar_max, sp[i]);
|
||||
}
|
||||
} else {
|
||||
const float * mp_f32 = (const float *)mp;
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
if (mp_f32[16*i] == -INFINITY) {
|
||||
nc = 16*i; break;
|
||||
}
|
||||
const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i);
|
||||
const __m512 x = _mm512_loadu_ps(sp + 16*i);
|
||||
const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
|
||||
@@ -6192,7 +6197,7 @@ inline float prepare_softmax(int nc, float * sp, float scale, float slope, const
|
||||
}
|
||||
}
|
||||
float vector_max = _mm512_reduce_max_ps(vmax);
|
||||
return std::max(scalar_max, vector_max);
|
||||
return std::make_pair(std::max(scalar_max, vector_max), nc);
|
||||
}
|
||||
|
||||
inline float do_soft_max(int nc, float * sp, float max) {
|
||||
@@ -6226,26 +6231,26 @@ inline void do_scale(int nc, const float * sp, float * dp, float scale) {
|
||||
}
|
||||
|
||||
void softmax_extended(int nc, float * sp, float * dp, float scale, float slope, const char * mp, bool mask_is_fp16) {
|
||||
auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
|
||||
auto sum = do_soft_max(nc, sp, max);
|
||||
do_scale(nc, sp, dp, 1.f/sum);
|
||||
auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
|
||||
auto sum = do_soft_max(ncc, sp, max);
|
||||
do_scale(ncc, sp, dp, 1.f/sum);
|
||||
}
|
||||
|
||||
float softmax_extended_dont_scale(int nc, float * sp, float scale, float slope, const char * mp, bool mask_is_fp16) {
|
||||
std::pair<float, int> softmax_extended_dont_scale(int nc, float * sp, float scale, float slope, const char * mp, bool mask_is_fp16) {
|
||||
if (nc/16 <= 16 && 16*(nc/16) == nc) {
|
||||
__m512 v[16];
|
||||
auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16, v);
|
||||
auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16, v);
|
||||
auto vmax = _mm512_set1_ps(-max);
|
||||
auto vsum = _mm512_setzero_ps();
|
||||
for (int i = 0; i < nc/16; ++i) {
|
||||
for (int i = 0; i < ncc/16; ++i) {
|
||||
v[i] = v_expf(_mm512_add_ps(v[i], vmax));
|
||||
vsum = _mm512_add_ps(vsum, v[i]);
|
||||
_mm512_storeu_ps(sp + 16*i, v[i]);
|
||||
}
|
||||
return _mm512_reduce_add_ps(vsum);
|
||||
return std::make_pair(_mm512_reduce_add_ps(vsum), ncc);
|
||||
}
|
||||
auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
|
||||
return do_soft_max(nc, sp, max);
|
||||
auto [max, ncc] = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
|
||||
return std::make_pair(do_soft_max(ncc, sp, max), ncc);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6372,15 +6377,16 @@ void iqk_flash_helper(int nq, // number of elements in q
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <int nq>
|
||||
inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) {
|
||||
if (smax > M) {
|
||||
if (M > -INFINITY) {
|
||||
float ms = expf(M - smax);
|
||||
auto vms = _mm512_set1_ps(ms);
|
||||
for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
|
||||
for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
|
||||
S *= ms;
|
||||
} else {
|
||||
for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps();
|
||||
for (int i = 0; i < nq/16; ++i) acc[i] = _mm512_setzero_ps();
|
||||
S = 0;
|
||||
}
|
||||
M = smax;
|
||||
@@ -6393,11 +6399,61 @@ inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m5
|
||||
S += saux[j];
|
||||
auto vs = _mm512_set1_ps(saux[j]);
|
||||
auto vr = (const ggml_half *)(v + stride_v*j);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
for (int i = 0; i < nq/16; ++i) {
|
||||
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int nq>
|
||||
void flash_attn_T(int nk, // number of rows in k
|
||||
int stride_k, // distance between rows in k (in bytes)
|
||||
int stride_v,
|
||||
const float * q, // q vector
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nk elements
|
||||
float scale,
|
||||
float slope,
|
||||
float * qkv) {
|
||||
constexpr int kNchunk = 16;
|
||||
__m512 vq[nq/16];
|
||||
__m512 acc[nq/16];
|
||||
float saux[kNchunk];
|
||||
for (int i = 0; i < nq/16; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
|
||||
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
|
||||
float M = -INFINITY;
|
||||
float S = 0;
|
||||
float smax = -INFINITY;
|
||||
int last_ik = 0;
|
||||
int ik = 0;
|
||||
for (; ik < nk; ++ik) {
|
||||
if (ik - last_ik == kNchunk) {
|
||||
accumulate<nq>(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
last_ik = ik;
|
||||
smax = -INFINITY;
|
||||
}
|
||||
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
if (mv == -INFINITY) break;
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < nq/16; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
saux[ik - last_ik] = s;
|
||||
smax = std::max(smax, s);
|
||||
}
|
||||
if (ik > last_ik) {
|
||||
accumulate<nq>(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
}
|
||||
if (S > 0) {
|
||||
auto norm = _mm512_set1_ps(1/S);
|
||||
for (int i = 0; i < nq/16; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
|
||||
} else {
|
||||
printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M);
|
||||
GGML_ASSERT(false);
|
||||
std::memset(qkv, 0, 128*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
@@ -6413,167 +6469,25 @@ void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
float * qk,
|
||||
float * qkv) {
|
||||
GGML_ASSERT(nq % 4 == 0);
|
||||
//GGML_ASSERT(nq / 16 <= 16);
|
||||
|
||||
if (nq == 128) {
|
||||
constexpr int kNchunk = 16;
|
||||
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
|
||||
__m512 vq[8];
|
||||
__m512 acc[8]; // = {};
|
||||
for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
|
||||
float M = -INFINITY;
|
||||
float S = 0;
|
||||
float saux[kNchunk];
|
||||
float smax = -INFINITY;
|
||||
int last_ik = 0;
|
||||
int ik = 0;
|
||||
for (; ik < nk; ++ik) {
|
||||
if (ik - last_ik == kNchunk) {
|
||||
accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
//if (smax > M) {
|
||||
// if (M > -INFINITY) {
|
||||
// float ms = expf(M - smax);
|
||||
// auto vms = _mm512_set1_ps(ms);
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
|
||||
// S *= ms;
|
||||
// }
|
||||
// M = smax;
|
||||
//}
|
||||
//auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M)));
|
||||
//_mm256_storeu_ps(saux, vs_all);
|
||||
//for (int j = 0; j < 8; ++j) {
|
||||
// S += saux[j];
|
||||
// auto vs = _mm512_set1_ps(saux[j]);
|
||||
// auto vr = (const ggml_half *)((const char *)v + stride_v*(last_ik + j));
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
//}
|
||||
last_ik = ik;
|
||||
smax = -INFINITY;
|
||||
switch (nq) {
|
||||
case 64: flash_attn_T< 64>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
case 80: flash_attn_T< 80>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
case 96: flash_attn_T< 96>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
case 112: flash_attn_T<112>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
case 128: flash_attn_T<128>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
case 256: flash_attn_T<256>(nk, stride_k, stride_v, q, k, v, mask, scale, slope, qkv); return;
|
||||
default: break;
|
||||
//default: GGML_ABORT("unhandled head size -> fatal error");
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int i = 0; i < nk; ++i) {
|
||||
if (GGML_FP16_TO_FP32(mp[i]) == -INFINITY) {
|
||||
nk = i; break;
|
||||
}
|
||||
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
if (mv == -INFINITY) break; //continue;
|
||||
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
saux[ik - last_ik] = s;
|
||||
smax = std::max(smax, s);
|
||||
}
|
||||
if (ik > last_ik) {
|
||||
accumulate(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
|
||||
//if (smax > M) {
|
||||
// if (M > -INFINITY) {
|
||||
// float ms = expf(M - smax);
|
||||
// auto vms = _mm512_set1_ps(ms);
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
|
||||
// S *= ms;
|
||||
// }
|
||||
// M = smax;
|
||||
//}
|
||||
//for (int j = last_ik; j < ik; ++j) {
|
||||
// float s = expf(saux[j - last_ik] - M);
|
||||
// S += s;
|
||||
// auto vs = _mm512_set1_ps(s);
|
||||
// auto vr = (const ggml_half *)((const char *)v + stride_v*j);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
//}
|
||||
}
|
||||
if (S > 0) {
|
||||
auto norm = _mm512_set1_ps(1/S);
|
||||
for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
|
||||
} else {
|
||||
printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M);
|
||||
GGML_ASSERT(false);
|
||||
std::memset(qkv, 0, 128*sizeof(float));
|
||||
}
|
||||
return;
|
||||
//int ik = 0;
|
||||
//if (false && nk >= 8) {
|
||||
// bool finished = false;
|
||||
// float s8[8];
|
||||
// for (int ik8 = 0; ik8 < nk/8; ++ik8) {
|
||||
// float smax = -INFINITY;
|
||||
// int j = 0;
|
||||
// for (; j < 8; ++j) {
|
||||
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f;
|
||||
// if (mv == -INFINITY) break;
|
||||
// const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j));
|
||||
// auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
// for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
// float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
// s8[j] = s;
|
||||
// smax = std::max(smax, s);
|
||||
// }
|
||||
// if (smax == -INFINITY) { finished = true; break; }
|
||||
// if (smax > M) {
|
||||
// if (M > -INFINITY) {
|
||||
// float ms = expf(M - smax);
|
||||
// auto scale = _mm512_set1_ps(ms);
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(acc[i], scale);
|
||||
// S *= ms;
|
||||
// } else {
|
||||
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps();
|
||||
// S = 0;
|
||||
// }
|
||||
// M = smax;
|
||||
// }
|
||||
// _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M))));
|
||||
// for (int l = 0; l < j; ++l) {
|
||||
// if (s8[l] <= 0.0f) continue;
|
||||
// const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+l));
|
||||
// auto vs = _mm512_set1_ps(s8[l]);
|
||||
// S += s8[l];
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// ik = finished ? nk : 8*(nk/8);
|
||||
//}
|
||||
//int last_ik = 0;
|
||||
//for (; ik < nk; ++ik) {
|
||||
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
|
||||
// if (mv == -INFINITY) break; //continue;
|
||||
// const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
|
||||
// const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik);
|
||||
// auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
|
||||
// for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
|
||||
// float s = scale * _mm512_reduce_add_ps(sum) + mv;
|
||||
// if (s > M) {
|
||||
// if (M > -INFINITY) {
|
||||
// float ms = expf(M - s);
|
||||
// auto vms = _mm512_set1_ps(ms);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)));
|
||||
// }
|
||||
// S = ms*S + 1;
|
||||
// } else {
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
|
||||
// }
|
||||
// S = 1;
|
||||
// }
|
||||
// M = s;
|
||||
// } else {
|
||||
// float vs = expf(s - M);
|
||||
// auto vvs = _mm512_set1_ps(vs);
|
||||
// for (int i = 0; i < 8; ++i) {
|
||||
// acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
|
||||
// }
|
||||
// S += vs;
|
||||
// }
|
||||
//}
|
||||
//if (S > 0) {
|
||||
// auto norm = _mm512_set1_ps(1/S);
|
||||
// for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
|
||||
//} else {
|
||||
// std::memset(qkv, 0, 128*sizeof(float));
|
||||
//}
|
||||
//return;
|
||||
}
|
||||
|
||||
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};
|
||||
@@ -6581,7 +6495,8 @@ void iqk_flash_helper_2(int nq, // number of elements in q
|
||||
mul_mat_fX_fY_1<ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
//mul_mat_fX_fY_T<1, ggml_half, float>(nq, k, stride_k, info, nk);
|
||||
//softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
|
||||
auto sum = softmax_extended_dont_scale(nk, qk, scale, slope, (const char *)mask, true);
|
||||
auto [sum, ncc] = softmax_extended_dont_scale(nk, qk, scale, slope, (const char *)mask, true);
|
||||
nk = ncc;
|
||||
|
||||
GGML_ASSERT(nq%16 == 0);
|
||||
if (nq/16 <= 16) {
|
||||
|
||||
Reference in New Issue
Block a user