mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 00:24:11 +00:00
Flash attention refinements
This commit is contained in:
@@ -6069,8 +6069,10 @@ struct FlashAttn {
|
||||
static_assert(k_step%16 == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
constexpr static int vk_size = D <= 256 ? D/8 : D/16;
|
||||
static_assert(q_step <= vk_size);
|
||||
constexpr static bool is_small_head = D <= 128;
|
||||
|
||||
constexpr static int vk_size = is_small_head ? D/8 : D/16;
|
||||
static_assert(2*q_step <= vk_size);
|
||||
|
||||
FlashAttn(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}
|
||||
|
||||
@@ -6080,6 +6082,7 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512 * qv) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
@@ -6102,38 +6105,31 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
inline void update_M_S(int j, [[maybe_unused]] const char * mask) {
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
if (mp[l1] == h_inf) {
|
||||
cache[k_step*m1 + l1] = -INFINITY;
|
||||
return;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
|
||||
}
|
||||
cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
|
||||
}
|
||||
|
||||
inline void update_M_S(int j) {
|
||||
if (softcap <= 0.0f) {
|
||||
if constexpr (D <= 256) {
|
||||
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
} else {
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp), _mm256_setzero_si256());
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
|
||||
} else {
|
||||
auto v_softcap = _mm512_set1_ps(softcap);
|
||||
if constexpr (D <= 256) {
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
//vk[l] = _mm512_mul_ps(vscale, v_tanh(_mm512_mul_ps(v_softcap, val)));
|
||||
vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val)));
|
||||
}
|
||||
} else {
|
||||
auto vinf = _mm512_set1_ps(-INFINITY);
|
||||
const ggml_half * mp = (const ggml_half *)mask;
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto m16 = _mm256_cmpeq_epi16_mask(_mm256_loadu_si256((const __m256i *)mp+l), _mm256_setzero_si256());
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
//val = v_tanh(_mm512_mul_ps(v_softcap, val));
|
||||
//vk[l] = _mm512_mask_mul_ps(vinf, m16, vscale, val);
|
||||
val = v_tanh(_mm512_mul_ps(vscale, val));
|
||||
vk[l] = _mm512_mask_mul_ps(vinf, m16, v_softcap, val);
|
||||
}
|
||||
for (int l = 0; l < k_step/16; ++l) {
|
||||
auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
|
||||
vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6158,6 +6154,7 @@ struct FlashAttn {
|
||||
}
|
||||
S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
|
||||
}
|
||||
|
||||
inline void normalize_and_store(int j, const float * R, float * qkv) const {
|
||||
GGML_ASSERT(S[j] > 0);
|
||||
auto norm = _mm512_set1_ps(1/S[j]);
|
||||
@@ -6167,8 +6164,6 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
inline void accumulate_qkv(int nq1, int stride_v, const char * v);
|
||||
|
||||
inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
@@ -6178,7 +6173,16 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
template <int Size = D, class = std::enable_if<Size <= 256>>
|
||||
inline void normalize_and_store(int stride_qkv, float * qkv) const {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
normalize_and_store(j, R, qkv);
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
__m512 qv[D/16];
|
||||
@@ -6193,25 +6197,19 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
|
||||
template <int Size = D, class = std::enable_if<Size >= 257>>
|
||||
inline void mult_mask_kq(int stride_k, int stride_q, const char * k, const float * q) {
|
||||
DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||
for (int i = 0; i < q_step/4; ++i) {
|
||||
mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
info.cur_y += 4;
|
||||
}
|
||||
int n_left = q_step - 4*(q_step/4);
|
||||
if (n_left > 0) {
|
||||
switch (n_left) {
|
||||
case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
default: break;
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto kr = (const ggml_half *)(k + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int Size = D, class = std::enable_if<Size <= 256>>
|
||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
||||
inline void mult_mask_kq(int nq, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
__m512 qv[D/16];
|
||||
@@ -6225,58 +6223,109 @@ struct FlashAttn {
|
||||
}
|
||||
}
|
||||
}
|
||||
template <int Size = D, class = std::enable_if<Size >= 257>>
|
||||
inline void mult_mask_kq(int nq, int stride_k, int stride_q, const char * k, const float * q) {
|
||||
DataInfo info{cache, (const char *)q, k_step*sizeof(float), stride_q*sizeof(float), 0, 0, nullptr, 0};
|
||||
for (int i = 0; i < nq/4; ++i) {
|
||||
mul_mat_fX_fY_T<4, ggml_half, float>(D, (const void *)k, stride_k, info, k_step);
|
||||
info.cur_y += 4;
|
||||
}
|
||||
int n_left = nq - 4*(nq/4);
|
||||
if (n_left > 0) {
|
||||
switch (n_left) {
|
||||
case 1: mul_mat_fX_fY_T<1, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 2: mul_mat_fX_fY_T<2, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
case 3: mul_mat_fX_fY_T<3, ggml_half, float>(D, (const void *)k, stride_k, info, k_step); break;
|
||||
default: break;
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
||||
inline void mult_mask_kq_l(int nq, int stride_k, int stride_q, int stride_m,
|
||||
const char * k, const float * q, const char * mask) {
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto kr = (const ggml_half *)(k + l1*stride_k);
|
||||
for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
|
||||
for (int m1 = 0; m1 < nq; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
|
||||
if constexpr (D <= 256) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(stride_k, stride_q, stride_m, k, q, mask);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq(stride_k, stride_q, k, q);
|
||||
mult_mask_kq_l(stride_k, stride_q, stride_m, k, q, mask);
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
update_M_S(j, mask);
|
||||
mask += stride_m;
|
||||
update_M_S(j);
|
||||
}
|
||||
}
|
||||
|
||||
inline void multiply_mask_kq(int nq, int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
|
||||
if constexpr (D <= 256) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(nq, stride_k, stride_q, stride_m, k, q, mask);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq(nq, stride_k, stride_q, k, q);
|
||||
mult_mask_kq_l(nq, stride_k, stride_q, stride_m, k, q, mask);
|
||||
}
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
update_M_S(j, mask);
|
||||
mask += stride_m;
|
||||
update_M_S(j);
|
||||
}
|
||||
}
|
||||
|
||||
inline void accumulate_qkv(int stride_v, const char * v);
|
||||
// This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
|
||||
// Hence, for now, we will not handle head sizes of 80 and 112
|
||||
inline void accumulate_qkv(int stride_v, const char * v) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void normalize_and_store(int stride_qkv, float * qkv) const {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
normalize_and_store(j, R, qkv);
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
template <int Nq = q_step, class = std::enable_if<Nq >= 2>>
|
||||
inline void accumulate_qkv(int nq1, int stride_v, const char * v) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6339,138 +6388,14 @@ struct FlashAttn {
|
||||
result = Op(Op_combine(vals[0], vals[1]));
|
||||
}
|
||||
else {
|
||||
auto vmax = vals[0];
|
||||
for (int l = 1; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
|
||||
auto vmax = Op_combine(vals[0], vals[1]);
|
||||
for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
|
||||
result = Op(vmax);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int stride_v, const char * v) {
|
||||
if constexpr (2*q_step <= vk_size) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[j] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[j] = _mm512_fmadd_ps(v, vs, vk[j]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
void FlashAttn<D, q_step, k_step>::accumulate_qkv(int nq1, int stride_v, const char * v) {
|
||||
if (2*nq1 <= vk_size) {
|
||||
for (int i = 0; i < D/16; i += 2) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
|
||||
vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
|
||||
vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
|
||||
auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
|
||||
vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[2*j+0]);
|
||||
_mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < D/16; ++i) {
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
if (need_scaling[j] == 2) {
|
||||
vk[j] = _mm512_setzero_ps();
|
||||
} else {
|
||||
auto R = qkv_cache + D*j;
|
||||
vk[j] = _mm512_loadu_ps(R + 16*i);
|
||||
if (need_scaling[j] == 1) {
|
||||
vk[j] = _mm512_mul_ps(vk[j], vms[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
auto vr = (const ggml_half *)(v + l1*stride_v);
|
||||
auto v = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i));
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
|
||||
vk[j] = _mm512_fmadd_ps(v, vs, vk[j]);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
auto R = qkv_cache + D*j;
|
||||
_mm512_storeu_ps(R + 16*i, vk[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int q_step, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
@@ -6514,12 +6439,14 @@ bool iqk_flash_attn_noalibi(int D, // head size
|
||||
switch (D) {
|
||||
case 64:
|
||||
iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 80:
|
||||
iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
// Disable until we fix accumulate_qkv for odd D/16
|
||||
//case 80:
|
||||
// iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 96:
|
||||
iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 112:
|
||||
iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
// Disable until we fix accumulate_qkv for odd D/16
|
||||
//case 112:
|
||||
// iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256:
|
||||
|
||||
Reference in New Issue
Block a user