Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-27 16:44:21 +00:00)
Zen4 Flash Attention (#32)
* Zen4 flash attention: moving useful parts from the kq_fused_softmax branch

* Add flash attention with soft-cap and fix D = 256 case

* Flash attention refinements

* Update FlashAttn comment

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
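For orientation (an editorial note, not part of the commit message): judging from the diff below, the new kernel computes, per query row, `qkv = sum_l softmax(f(k_l·q))_l * v_l`, where `f(x) = softcap*tanh(scale*x)` when `softcap > 0` and `f(x) = scale*x` otherwise, and key positions whose fp16 mask entry is `-inf` are skipped. A minimal scalar sketch under those assumptions, with hypothetical names:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference for one query row; the commit's AVX-512 code computes the same thing
// blockwise with a streaming softmax.
std::vector<float> attn_row_ref(const std::vector<std::vector<float>>& K,   // nk x D keys
                                const std::vector<std::vector<float>>& V,   // nk x D values
                                const std::vector<float>& q,                // D-dim query
                                const std::vector<float>& mask,             // nk entries, 0 or -inf
                                float scale, float softcap) {
    const size_t nk = K.size(), D = q.size();
    std::vector<float> s(nk, -INFINITY);
    float smax = -INFINITY;
    for (size_t l = 0; l < nk; ++l) {
        if (mask[l] == -INFINITY) continue;                  // masked-out key position
        float dot = 0;
        for (size_t i = 0; i < D; ++i) dot += K[l][i]*q[i];  // k·q
        s[l] = softcap > 0 ? softcap*std::tanh(scale*dot) : scale*dot;
        smax = std::max(smax, s[l]);
    }
    std::vector<float> out(D, 0.0f);
    float sum = 0;
    for (size_t l = 0; l < nk; ++l) {
        float p = s[l] > -INFINITY ? std::exp(s[l] - smax) : 0.0f;
        sum += p;
        for (size_t i = 0; i < D; ++i) out[i] += p*V[l][i];
    }
    for (size_t i = 0; i < D; ++i) out[i] /= sum;            // normalize by the softmax denominator
    return out;
}
```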
@@ -16149,6 +16149,38 @@ static void ggml_compute_forward_flash_attn_ext_f16(
        scale /= softcap;
    }

#if GGML_USE_IQK_MULMAT
    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16 &&
        mask && mask->type == GGML_TYPE_F16) {
        int64_t work_per_slice = D*nek1*neq1;
        int ntg = 1;
        if      (nth%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
        else if (nth%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
        else if (nth%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
        if ((neq2*neq3)%(nth/ntg) == 0) {
            //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d\n", __func__, (int)D, (int)neq2, (int)neq1, (int)nek1);
            int counter = 0;
            for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
                for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
                    if (counter++ % (nth/ntg) == ith/ntg) {
                        int iq1 = (ith%ntg)*neq1/ntg;
                        if (!iqk_flash_attn_noalibi(D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
                                (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
                                (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
                                (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
                                (const void *)((const char *)mask->data + iq1*mask->nb[1]),
                                scale, softcap,
                                (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
                    }
                }
            }
            return;
        }
IQK_Flash_Attn_NotAvailable:;
    }

#endif

    const uint32_t n_head      = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

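The dispatch above splits the work before calling `iqk_flash_attn_noalibi`: threads are grouped into teams of `ntg`, each team takes every `(nth/ntg)`-th `(iq3, iq2)` head slice, and within a team thread `ith % ntg` handles `neq1/ntg` consecutive q rows. A hedged sketch of the same mapping (editorial; the helper and struct are hypothetical and not in the commit):

```cpp
#include <cstdint>

// Hypothetical helper illustrating the mapping used by the dispatch loop above.
struct Slice { int64_t iq3, iq2; int iq1_first, n_rows; };

static bool slice_for_thread(int64_t neq1, int64_t neq2, int ith, int nth, int ntg,
                             int counter, Slice & s) {
    const int teams = nth/ntg;                    // number of thread teams
    if (counter % teams != ith/ntg) return false; // this (iq3, iq2) pair belongs to another team
    s.iq3       = counter / neq2;                 // counter enumerates (iq3, iq2) row-major
    s.iq2       = counter % neq2;
    s.iq1_first = (ith % ntg) * neq1 / ntg;       // first q row handled by this team member
    s.n_rows    = neq1 / ntg;                     // rows passed as nq1 to iqk_flash_attn_noalibi
    return true;
}
```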
@@ -5915,6 +5915,575 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {

#endif // __aarch64__

namespace {

#if defined(__ARM_NEON) && defined(__aarch64__)
// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline float32x4_t v_expf(float32x4_t x) {
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
inline float32x4_t v_tanh(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
    const float32x4_t exp_two_x = v_expf(two_x);
    const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
    const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
    return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
    //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
#endif

#if defined(__AVX512F__) && defined(__AVX512DQ__)

// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline __m512 v_expf(__m512 x) {
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);
    const __m512 b =
        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                         _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    const __mmask16 d =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u,
        _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    const __m512 res = _mm512_scalef_ps(j, n);
    if (_mm512_kortestz(d, d))
        return res;
    const __m512 zero = _mm512_setzero_ps();
    const __m512 alt = _mm512_mask_blend_ps(
        _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(d, res, alt);
}
inline __m512 v_tanh(__m512 x) {
    const __m512 one = _mm512_set1_ps(1.0f);
    const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
    const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
    const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
    return _mm512_mask_blend_ps(mask, res, one);
}
#endif

#if defined(__AVX2__) && defined(__FMA__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline __m256 v_expf(__m256 x) {
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    const __m256 n = _mm256_sub_ps(z, r);
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    const __m256 k = _mm256_castsi256_ps(
        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(126), _CMP_GT_OQ));
    const __m256 u = _mm256_mul_ps(b, b);
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 =
        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(192), _CMP_GT_OQ));
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c),
                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
inline __m256 v_tanh(__m256 x) {
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
    const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
    return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
}

#endif
} // namespace
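The three `v_tanh` variants above use the same identity: `tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)`, evaluated with the corresponding `v_expf`, and blended to 1 for `x > 10`, where `e^{2x}` would exceed the ~88.38 overflow limit noted in the comments. A scalar model of that behavior (editorial sketch, not part of the commit):

```cpp
#include <cmath>

// Scalar equivalent of the SIMD v_tanh helpers in the diff above.
static inline float tanh_via_exp(float x) {
    if (x > 10.0f) return 1.0f;         // same cutoff the SIMD versions blend against
    const float e2x = std::exp(2.0f*x); // v_expf in the SIMD code
    return (e2x - 1.0f) / (e2x + 1.0f);
}
```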

#ifdef HAVE_FANCY_SIMD

namespace {

// Some of the methods in FlashAttn have two identical implementations that only differ by
// one version using a loop over the template parameter q_step, while the other uses a loop
// over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot,
// but performance drops significantly if I remove the version with fixed q_step iterations.
// We only instantiate FlashAttn with q_step = 1 and q_step = 4 or 8 (depending on head size D),
// so when we have to process Nq rows, we process q_step*(Nq/q_step) using fixed q_step loops,
// and use the variable nq version (with lower performance) only for the remaining 1...q_step-1
// rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to
// process a template parameter of such functions, but this would result in the compiler generating
// q_step-1 versions of these functions for us, which I thought was too much with q_step = 8.
template <int D, int q_step, int k_step>
struct FlashAttn {
    static_assert(D%16 == 0 && D <= 256);
    static_assert(k_step%16 == 0);
    static_assert(q_step <= 4 || q_step%4 == 0);

    constexpr static bool is_small_head = D <= 128;

    constexpr static int vk_size = is_small_head ? D/8 : D/16;
    static_assert(2*q_step <= vk_size);

    FlashAttn(float scale, float softcap) : vscale(_mm512_set1_ps(scale)), softcap(softcap), h_inf(GGML_FP32_TO_FP16(-INFINITY)) {}

    inline void init_qstep() {
        for (int j = 0; j < q_step; ++j) {
            S[j] = 0; M[j] = -INFINITY;
        }
    }

    template <bool small = is_small_head, class = std::enable_if<small>>
    inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, __m512 * qv) {
        // q index is q_step*i1 + m1
        // k index is k_step*k1 + l1
        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
        cache[k_step*m1 + l1 + 0] = cache[k_step*m1 + l1 + 1] = -INFINITY;
        if (mp[l1+0] == h_inf && mp[l1+1] == h_inf) {
            return;
        }
        auto qr = q + m1*stride_q;
        for (int i = 0; i < D/16; ++i) qv[i] = _mm512_loadu_ps(qr + 16*i);
        if (mp[l1+0] != h_inf) {
            auto vsum = _mm512_mul_ps(vk[0], qv[0]);
            for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i], qv[i], vsum);
            cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum);
        }
        if (mp[l1+1] != h_inf) {
            auto vsum = _mm512_mul_ps(vk[D/16], qv[0]);
            for (int i = 1; i < D/16; ++i) vsum = _mm512_fmadd_ps(vk[i+D/16], qv[i], vsum);
            cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum);
        }
    }

    template <bool small = is_small_head, class = std::enable_if<!small>>
    inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask) {
        // q index is q_step*i1 + m1
        // k index is k_step*k1 + l1
        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
        if (mp[l1] == h_inf) {
            cache[k_step*m1 + l1] = -INFINITY;
            return;
        }
        auto qr = q + m1*stride_q;
        auto vsum = _mm512_mul_ps(vk[0], _mm512_loadu_ps(qr));
        for (int i = 0; i < D/16; ++i) {
            vsum = _mm512_fmadd_ps(vk[i], _mm512_loadu_ps(qr + 16*i), vsum);
        }
        cache[k_step*m1 + l1] = _mm512_reduce_add_ps(vsum);
    }

    inline void update_M_S(int j) {
        if (softcap <= 0.0f) {
            for (int l = 0; l < k_step/16; ++l) vk[l] = _mm512_mul_ps(vscale, _mm512_loadu_ps(cache + k_step*j + 16*l));
        } else {
            auto v_softcap = _mm512_set1_ps(softcap);
            for (int l = 0; l < k_step/16; ++l) {
                auto val = _mm512_loadu_ps(cache + k_step*j + 16*l);
                vk[l] = _mm512_mul_ps(v_softcap, v_tanh(_mm512_mul_ps(vscale, val)));
            }
        }

        float smax = reduce_T<_mm512_reduce_max_ps, _mm512_max_ps>(vk);
        need_scaling[j] = 0;
        if (smax > M[j]) {
            if (M[j] > -INFINITY) {
                float m = expf(M[j] - smax);
                vms[j] = _mm512_set1_ps(m);
                need_scaling[j] = 1;
                S[j] *= m;
            } else {
                need_scaling[j] = 2;
                S[j] = 0;
            }
            M[j] = smax;
        }
        auto vm = _mm512_set1_ps(M[j]);
        for (int l = 0; l < k_step/16; ++l) {
            vk[l] = v_expf(_mm512_sub_ps(vk[l], vm));
            _mm512_storeu_ps(cache + k_step*j + 16*l, vk[l]);
        }
        S[j] += reduce_T<_mm512_reduce_add_ps, _mm512_add_ps>(vk);
    }

    inline void normalize_and_store(int j, const float * R, float * qkv) const {
        GGML_ASSERT(S[j] > 0);
        auto norm = _mm512_set1_ps(1/S[j]);
        for (int i = 0; i < D/16; ++i) {
            auto r = _mm512_loadu_ps(R + 16*i);
            _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, r));
        }
    }

    inline void normalize_and_store(int nq1, int stride_qkv, float * qkv) const {
        auto R = qkv_cache;
        for (int j = 0; j < nq1; ++j) {
            normalize_and_store(j, R, qkv);
            qkv += stride_qkv;
            R   += D;
        }
    }

    inline void normalize_and_store(int stride_qkv, float * qkv) const {
        auto R = qkv_cache;
        for (int j = 0; j < q_step; ++j) {
            normalize_and_store(j, R, qkv);
            qkv += stride_qkv;
            R   += D;
        }
    }

    template <bool small = is_small_head, class = std::enable_if<small>>
    inline void mult_mask_kq(int stride_k, int stride_q, int stride_m,
            const char * k, const float * q, const char * mask) {
        __m512 qv[D/16];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
            auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
            for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
            for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
            for (int m1 = 0; m1 < q_step; ++m1) {
                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
            }
        }
    }

    template <bool small = is_small_head, class = std::enable_if<!small>>
    inline void mult_mask_kq_l(int stride_k, int stride_q, int stride_m,
            const char * k, const float * q, const char * mask) {
        for (int l1 = 0; l1 < k_step; ++l1) {
            auto kr = (const ggml_half *)(k + l1*stride_k);
            for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
            for (int m1 = 0; m1 < q_step; ++m1) {
                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
            }
        }
    }

    template <bool small = is_small_head, class = std::enable_if<small>>
    inline void mult_mask_kq(int nq, int stride_k, int stride_q, int stride_m,
            const char * k, const float * q, const char * mask) {
        __m512 qv[D/16];
        for (int l1 = 0; l1 < k_step; l1 += 2) {
            auto kr1 = (const ggml_half *)(k + (l1 + 0)*stride_k);
            auto kr2 = (const ggml_half *)(k + (l1 + 1)*stride_k);
            for (int i = 0; i < D/16; ++i) vk[i+ 0] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr1 + i));
            for (int i = 0; i < D/16; ++i) vk[i+D/16] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr2 + i));
            for (int m1 = 0; m1 < nq; ++m1) {
                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv);
            }
        }
    }

    template <bool small = is_small_head, class = std::enable_if<!small>>
    inline void mult_mask_kq_l(int nq, int stride_k, int stride_q, int stride_m,
            const char * k, const float * q, const char * mask) {
        for (int l1 = 0; l1 < k_step; ++l1) {
            auto kr = (const ggml_half *)(k + l1*stride_k);
            for (int i = 0; i < D/16; ++i) vk[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i));
            for (int m1 = 0; m1 < nq; ++m1) {
                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask);
            }
        }
    }

    inline void multiply_mask_kq(int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
        if constexpr (is_small_head) {
            mult_mask_kq(stride_k, stride_q, stride_m, k, q, mask);
        }
        else {
            mult_mask_kq_l(stride_k, stride_q, stride_m, k, q, mask);
        }
        for (int j = 0; j < q_step; ++j) {
            update_M_S(j);
        }
    }

    inline void multiply_mask_kq(int nq, int stride_k, int stride_q, int stride_m, const char * k, const float * q, const char * mask) {
        if constexpr (is_small_head) {
            mult_mask_kq(nq, stride_k, stride_q, stride_m, k, q, mask);
        }
        else {
            mult_mask_kq_l(nq, stride_k, stride_q, stride_m, k, q, mask);
        }
        for (int j = 0; j < nq; ++j) {
            update_M_S(j);
        }
    }

    // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2
    // Hence, for now, we will not handle head sizes of 80 and 112
    inline void accumulate_qkv(int stride_v, const char * v) {
        for (int i = 0; i < D/16; i += 2) {
            for (int j = 0; j < q_step; ++j) {
                if (need_scaling[j] == 2) {
                    vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
                    vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
                    if (need_scaling[j] == 1) {
                        vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
                        vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
                    }
                }
            }
            for (int l1 = 0; l1 < k_step; ++l1) {
                auto vr = (const ggml_half *)(v + l1*stride_v);
                auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
                auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
                for (int j = 0; j < q_step; ++j) {
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
                    vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
                }
            }
            for (int j = 0; j < q_step; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i,      vk[2*j+0]);
                _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
            }
        }
    }

    template <int Nq = q_step, class = std::enable_if<Nq >= 2>>
    inline void accumulate_qkv(int nq1, int stride_v, const char * v) {
        for (int i = 0; i < D/16; i += 2) {
            for (int j = 0; j < nq1; ++j) {
                if (need_scaling[j] == 2) {
                    vk[2*j+0] = vk[2*j+1] = _mm512_setzero_ps();
                } else {
                    auto R = qkv_cache + D*j;
                    vk[2*j+0] = _mm512_loadu_ps(R + 16*i);
                    vk[2*j+1] = _mm512_loadu_ps(R + 16*i + 16);
                    if (need_scaling[j] == 1) {
                        vk[2*j+0] = _mm512_mul_ps(vk[2*j+0], vms[j]);
                        vk[2*j+1] = _mm512_mul_ps(vk[2*j+1], vms[j]);
                    }
                }
            }
            for (int l1 = 0; l1 < k_step; ++l1) {
                auto vr = (const ggml_half *)(v + l1*stride_v);
                auto v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+0));
                auto v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)vr+i+1));
                for (int j = 0; j < nq1; ++j) {
                    auto vs = _mm512_set1_ps(cache[k_step*j + l1]);
                    vk[2*j+0] = _mm512_fmadd_ps(v1, vs, vk[2*j+0]);
                    vk[2*j+1] = _mm512_fmadd_ps(v2, vs, vk[2*j+1]);
                }
            }
            for (int j = 0; j < nq1; ++j) {
                auto R = qkv_cache + D*j;
                _mm512_storeu_ps(R + 16*i,      vk[2*j+0]);
                _mm512_storeu_ps(R + 16*i + 16, vk[2*j+1]);
            }
        }
    }

    void compute(int nq1, int nk1, int stride_k, int stride_q, int stride_m, int stride_v, int stride_qkv,
            const char * k, const float * q, const char * mask, const char * v, float * qkv) {
        for (int i1 = 0; i1 < nq1/q_step; ++i1) {
            init_qstep();
            auto kr = k;
            auto vr = v;
            auto mr = mask;
            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
                multiply_mask_kq(stride_k, stride_q, stride_m, kr, q, mr);
                accumulate_qkv(stride_v, vr);
                kr += k_step*stride_k;
                vr += k_step*stride_v;
                mr += k_step*sizeof(ggml_half);
            }
            normalize_and_store(stride_qkv, qkv);

            q    += q_step*stride_q;
            mask += q_step*stride_m;
            qkv  += q_step*stride_qkv;
        }
        int n_left = nq1 - q_step*(nq1/q_step);
        if (n_left > 0) {
            init_qstep();
            auto kr = k;
            auto vr = v;
            auto mr = mask;
            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
                multiply_mask_kq(n_left, stride_k, stride_q, stride_m, kr, q, mr);
                accumulate_qkv(n_left, stride_v, vr);
                kr += k_step*stride_k;
                vr += k_step*stride_v;
                mr += k_step*sizeof(ggml_half);
            }
            normalize_and_store(n_left, stride_qkv, qkv);
        }
    }

    float cache[q_step*k_step];
    float qkv_cache[D*q_step];
    float S[q_step], M[q_step];
    int   need_scaling[q_step];
    __m512 vms[q_step];
    __m512 vk[vk_size];
    const __m512 vscale;
    const float  softcap;
    const ggml_half h_inf;

    typedef __m512 (*combine_t)(__m512, __m512);
    typedef float  (*reduce_t)(__m512);
    template <reduce_t Op, combine_t Op_combine>
    static inline float reduce_T(const __m512 * vals) {
        float result;
        if constexpr (k_step/16 == 1) {
            result = Op(vals[0]);
        }
        else if constexpr (k_step/16 == 2) {
            result = Op(Op_combine(vals[0], vals[1]));
        }
        else {
            auto vmax = Op_combine(vals[0], vals[1]);
            for (int l = 2; l < k_step/16; ++l) vmax = Op_combine(vmax, vals[l]);
            result = Op(vmax);
        }
        return result;
    }
};

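Taken together, `update_M_S` and `accumulate_qkv` implement the usual streaming-softmax recurrence: keep a running maximum `M`, a running normalizer `S`, and an unnormalized output accumulator per q row, and rescale them by `exp(M_old - M_new)` whenever a later k-block raises the maximum. A scalar sketch of one k-block, with hypothetical names mirroring the struct members (editorial, not part of the commit):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One k_step block for a single q row; M, S, acc mirror M[j], S[j], qkv_cache.
void process_block(const std::vector<float>& s,                  // scaled (and soft-capped) k·q scores for this block
                   const std::vector<std::vector<float>>& Vblk,  // k_step x D slice of V
                   float & M, float & S, std::vector<float>& acc) {
    float smax = -INFINITY;
    for (float x : s) smax = std::max(smax, x);
    if (smax > M) {                                                     // a new running maximum appeared
        const float m = M > -INFINITY ? std::exp(M - smax) : 0.0f;      // need_scaling == 1 (rescale) or 2 (reset)
        for (float & a : acc) a *= m;
        S *= m;
        M = smax;
    }
    for (size_t l = 0; l < s.size(); ++l) {
        const float p = s[l] > -INFINITY ? std::exp(s[l] - M) : 0.0f;   // softmax numerator for this key
        S += p;
        for (size_t i = 0; i < acc.size(); ++i) acc[i] += p*Vblk[l][i]; // accumulate p * v row
    }
    // After the last block, normalize_and_store divides acc by S.
}
```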
template <int D, int q_step, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
                        const float * q, const char * k, const char * v, const char * mask,
                        float scale, float softcap, float * qkv) {
    if (nq1 >= q_step) {
        FlashAttn<D, q_step, k_step> fa(scale, softcap);
        fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
                (const char *)k, q, (const char *)mask, (const char *)v, qkv);
    } else {
        FlashAttn<D, 1, k_step> fa(scale, softcap);
        fa.compute(nq1, nk1, stride_k, stride_q, stride_m, stride_v, stride_qkv,
                (const char *)k, q, (const char *)mask, (const char *)v, qkv);
    }
}
}

bool iqk_flash_attn_noalibi(int D,              // head size
                            int nq1,            // number of columns in q
                            int nk1,            // number of rows in k
                            int stride_q,       // distance between q columns in bytes
                            int stride_k,       // distance between k rows in bytes
                            int stride_v,       // distance between v rows in bytes
                            int stride_m,       // distance between mask rows (in bytes)
                            int stride_qkv,     // distance between rows in qkv (in floats)
                            const float * q,    // q matrix.
                            const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
                            const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
                            const void  * mask, // mask. If not null, assumed to be fp16. nq x nk elements
                            float scale,        // scale applied before softmax
                            float softcap,      // if > 0, a "soft-cap" operation is applied before softmax
                            float * qkv) {      // v*softmax(scale*(k*q))

    if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32

    auto ck = (const char *)k;
    auto cv = (const char *)v;
    auto cm = (const char *)mask;

    stride_q /= sizeof(float); // q stride as float

    switch (D) {
        case 64:
            iqk_flash_helper_T< 64, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        // Disable until we fix accumulate_qkv for odd D/16
        //case 80:
        //    iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 96:
            iqk_flash_helper_T< 96, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        // Disable until we fix accumulate_qkv for odd D/16
        //case 112:
        //    iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 128:
            iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        case 256:
            iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
        default:
            return false;
    }

    return true;
}

#else
// TODO
bool iqk_flash_attn_noalibi([[maybe_unused]] int D,              // head size
                            [[maybe_unused]] int nq,             // number of columns in q
                            [[maybe_unused]] int nk,             // number of rows in k
                            [[maybe_unused]] int stride_q,       // distance between q columns in bytes
                            [[maybe_unused]] int stride_k,       // distance between k rows in bytes
                            [[maybe_unused]] int stride_v,       // distance between v rows in bytes
                            [[maybe_unused]] int stride_m,       // distance between mask rows (in bytes)
                            [[maybe_unused]] int stride_qkv,     // distance between rows in qkv (in floats)
                            [[maybe_unused]] const float * q,    // q matrix.
                            [[maybe_unused]] const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * mask, // mask. If not null, assumed to be fp16. nq x nk elements
                            [[maybe_unused]] float scale,        // scale applied before softmax
                            [[maybe_unused]] float softcap,      // if > 0, a "soft-cap" operation is applied before softmax
                            [[maybe_unused]] float * qkv) {      // v*softmax(scale*(k*q))
    return false;
}

#endif

#else // IQK_IMPLEMENT

bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
@@ -5926,4 +6495,22 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
    return false;
}

bool iqk_flash_attn_noalibi([[maybe_unused]] int D,              // head size
                            [[maybe_unused]] int nq,             // number of columns in q
                            [[maybe_unused]] int nk,             // number of rows in k
                            [[maybe_unused]] int stride_q,       // distance between q columns in bytes
                            [[maybe_unused]] int stride_k,       // distance between k rows in bytes
                            [[maybe_unused]] int stride_v,       // distance between v rows in bytes
                            [[maybe_unused]] int stride_m,       // distance between mask rows (in bytes)
                            [[maybe_unused]] int stride_qkv,     // distance between rows in qkv (in floats)
                            [[maybe_unused]] const float * q,    // q matrix.
                            [[maybe_unused]] const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
                            [[maybe_unused]] const void  * mask, // mask. If not null, assumed to be fp16. nq x nk elements
                            [[maybe_unused]] float scale,        // scale applied before softmax
                            [[maybe_unused]] float softcap,      // if > 0, a "soft-cap" operation is applied before softmax
                            [[maybe_unused]] float * qkv) {      // v*softmax(scale*(k*q))
    return false;
}

#endif

@@ -21,6 +21,21 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
                     int typeB, const void * B, long strideB,
                     float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);

bool iqk_flash_attn_noalibi(int D,              // head size
                            int nq,             // number of columns in q
                            int nk,             // number of rows in k
                            int stride_q,       // distance between q columns in bytes
                            int stride_k,       // distance between k rows in bytes
                            int stride_v,       // distance between v rows in bytes
                            int stride_m,       // distance between mask rows (in bytes)
                            int stride_qkv,     // distance between rows in qkv (in floats)
                            const float * q,    // q matrix.
                            const void  * k,    // k matrix. Assumed to be fp16, nq x nk elements
                            const void  * v,    // v matrix. Assumed to be fp16, nq x nk elements
                            const void  * mask, // mask. If not null, assumed to be fp16. nq x nk elements
                            float scale,        // scale applied before softmax
                            float softcap,      // if > 0, a "soft-cap" operation is applied before softmax
                            float * qkv);       // v*softmax(scale*(k*q))

#ifdef __cplusplus
}

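A hedged usage sketch of the exported entry point (editorial, not part of the commit): contiguous fp32 `q` and fp16 `k`/`v`/`mask` buffers, byte strides for the inputs, and a float stride for the output, mirroring how the ggml dispatch in the first hunk calls it. The fp16 buffers are modeled as `uint16_t`; an all-zero bit pattern is fp16 `0.0f`, i.e. an unmasked position.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Assumes the iqk_flash_attn_noalibi declaration from the header hunk above is in scope.
void run_flash_attn_example() {
    const int D = 128, nq = 8, nk = 64;                   // supported head size; nk must be a multiple of 32
    std::vector<float>    q(nq*D, 0.01f), out(nq*D);
    std::vector<uint16_t> k(nk*D), v(nk*D), mask(nq*nk);  // fp16 payloads; zero bits == fp16 0.0f (unmasked)
    bool ok = iqk_flash_attn_noalibi(D, nq, nk,
                                     D*(int)sizeof(float),     // stride_q:   bytes between q rows
                                     D*(int)sizeof(uint16_t),  // stride_k:   bytes between k rows
                                     D*(int)sizeof(uint16_t),  // stride_v:   bytes between v rows
                                     nk*(int)sizeof(uint16_t), // stride_m:   bytes between mask rows
                                     D,                        // stride_qkv: floats between output rows
                                     q.data(), k.data(), v.data(), mask.data(),
                                     1.0f/std::sqrt((float)D), // scale = 1/sqrt(D)
                                     0.0f,                     // softcap disabled
                                     out.data());
    if (!ok) { /* unsupported head size or missing mask: fall back to the generic ggml path */ }
}
```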