diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index adff7669..755d34ce 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2075,6 +2075,19 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
     return max;
 }
 
+static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
+    GGML_ASSERT(n%16 == 0);
+    __m512 vmax = _mm512_set1_ps(-INFINITY);
+    __m512 vinf = _mm512_set1_ps(-INFINITY);
+    const __mmask16 * mm16 = (const __mmask16 *)x;
+    for (int j = 0; j < n/16; ++j) {
+        __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf);
+        _mm512_storeu_ps(y + 16*j, v);
+        vmax = _mm512_max_ps(vmax, v);
+    }
+    return _mm512_reduce_max_ps(vmax);
+}
+
 static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -2646,6 +2659,13 @@ inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
     return _mm512_mul_ps(th, s_after);
 }
 
+inline static __m512 ggml_v_softcap_mask(__m512 x, __m512 s_before, __m512 s_after, __m512 src, __mmask16 mask) {
+    const __m512 one = _mm512_set1_ps(1.0f);
+    const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before));
+    const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+    return _mm512_mask_mul_ps(src, mask, th, s_after);
+}
+
 inline static __m512 ggml_v_gelu(__m512 x, __m512 c1, __m512 c2) {
     const __m512 one = _mm512_set1_ps(1.0f);
     __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
@@ -2883,6 +2903,20 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, float s_before, float s_after) {
     }
 }
 
+static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
+    const __mmask16 * m16 = (const __mmask16 *)mask;
+    __m512 vinf = _mm512_set1_ps(-INFINITY);
+    __m512 vmax = vinf;
+    __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+    __m512 vs_after  = _mm512_set1_ps(s_after);
+    for (int i = 0; i < n/16; ++i) {
+        __m512 v = ggml_v_softcap_mask(_mm512_loadu_ps(x + 16*i), vs_before, vs_after, vinf, m16[i]);
+        _mm512_storeu_ps(y + 16*i, v);
+        vmax = _mm512_max_ps(vmax, v);
+    }
+    return _mm512_reduce_max_ps(vmax);
+}
+
 static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
@@ -6045,10 +6079,10 @@ static struct ggml_tensor * ggml_softcap_max_impl(
     GGML_ASSERT(ggml_is_padded_1d(a));
 
     if (mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
-        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        //GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
@@ -13799,6 +13833,7 @@ static void ggml_compute_forward_softcap_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+    const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         // ALiBi
@@ -13809,27 +13844,36 @@ static void ggml_compute_forward_softcap_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]);
+        const int mask_row = i1%ne01;
 
         float max = -INFINITY;
-        if (mp_f32) {
-            if (use_f16) {
-                max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
-                //for (int i = 0; i < nc; ++i) {
-                //    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
-                //}
-            } else {
-                max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
-                //for (int i = 0; i < nc; ++i) {
-                //    wp[i] += slope*mp_f32[i];
-                //}
+        if (use_i32) {
+            int n32 = (ne00 + 31)/32;
+            const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32;
+            max = ggml_vec_cpy_softcap_mask_f32(nc, sp, wp, mp_u32, values[2], values[0]*values[3]);
+        } else {
+
+            ggml_vec_cpy_softcap_f32(nc, sp, wp, values[2], values[0]*values[3]);
+
+            if (src1) {
+                if (use_f16) {
+                    ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00;
+                    max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
+                } else if (use_i32) {
+                    int n32 = (ne00 + 31)/32;
+                    const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32;
+                    max = ggml_vec_add_f32_infmask(nc, mp_u32, wp);
+                } else {
+                    float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00;
+                    max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
+                    //for (int i = 0; i < nc; ++i) {
+                    //    wp[i] += slope*mp_f32[i];
+                    //}
+                }
+            }
+            else {
+                ggml_vec_max_f32(nc, &max, wp);
             }
-        }
-        else {
-            ggml_vec_max_f32(nc, &max, wp);
         }
 
 #ifndef NDEBUG
diff --git a/src/llama.cpp b/src/llama.cpp
index 76aa3fb8..51b5bbaf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8687,10 +8687,10 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16;
-        lctx.inp_KQ_mask = causal
-            ? ggml_new_tensor_2d(ctx0, type, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32;
+        auto nx = causal ? n_kv : n_tokens;
+        if (type == GGML_TYPE_I32) nx = (nx + 31)/32;
+        lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
         return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
@@ -8705,11 +8705,10 @@ struct llm_build_context {
 
     struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
         GGML_ASSERT(hparams.n_swa > 0);
-
-        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16;
-        lctx.inp_KQ_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, type, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_I32;
+        auto nx = causal ? n_kv : n_tokens;
+        if (type == GGML_TYPE_I32) nx = (nx + 31)/32;
+        lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(lctx.inp_KQ_mask_swa);
 
@@ -14288,40 +14287,45 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
             }
 
-            auto float_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type;
-            GGML_ASSERT(float_type == GGML_TYPE_F16 || float_type == GGML_TYPE_F32);
+            auto mask_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type;
+            GGML_ASSERT(mask_type == GGML_TYPE_I32 || mask_type == GGML_TYPE_F32);
 
-            if (float_type == GGML_TYPE_F16) {
+            if (mask_type == GGML_TYPE_I32) {
                 // in order this to be true, we are not using alibi
                 GGML_ASSERT(!hparams.use_alibi);
-                auto h_zero = ggml_fp32_to_fp16(0.0f);
-                auto h_inf  = ggml_fp32_to_fp16(-INFINITY);
-                ggml_fp16_t * h_data     = lctx.inp_KQ_mask     ? (ggml_fp16_t *)lctx.inp_KQ_mask->data     : nullptr;
-                ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr;
+                uint32_t * h_data     = lctx.inp_KQ_mask     ? (uint32_t *)lctx.inp_KQ_mask->data     : nullptr;
+                uint32_t * h_data_swa = lctx.inp_KQ_mask_swa ? (uint32_t *)lctx.inp_KQ_mask_swa->data : nullptr;
                 for (int j = 0; j < n_tokens; ++j) {
                     const llama_pos    pos    = batch.pos[j];
                     const llama_seq_id seq_id = batch.seq_id[j][0];
+                    uint32_t u = 0, u_swa = 0;
+                    uint32_t m = 1;
+
                     for (int i = 0; i < n_kv; ++i) {
-                        auto f = lctx.kv_self.cells[i].pos <= pos && lctx.kv_self.cells[i].has_seq_id(seq_id) ? h_zero : h_inf;
-                        if (h_data) h_data[j*n_kv + i] = f;
-                        if (h_data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) f = h_inf;
-                            h_data_swa[j*n_kv + i] = f;
+                        if (lctx.kv_self.cells[i].pos > pos || !lctx.kv_self.cells[i].has_seq_id(seq_id)) {
+                            u |= m; u_swa |= m;
+                        }
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) u_swa |= m;
+                        m <<= 1;
+                        if (!m) {
+                            if (h_data) *h_data++ = ~u;
+                            if (h_data_swa) *h_data_swa++ = ~u_swa;
+                            u = u_swa = 0; m = 1;
                         }
                     }
+                    if (m > 1) {
+                        if (h_data) *h_data++ = ~u;
+                        if (h_data_swa) *h_data_swa++ = ~u_swa;
+                    }
                 }
 
-                if (h_data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) h_data[i*n_kv + j] = h_inf;
-                    }
-                }
-
-                if (h_data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) h_data_swa[i*n_kv + j] = h_inf;
-                    }
+                auto n_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);
+                if (n_pad > n_tokens) {
+                    auto n_kv_32 = (n_kv + 31)/32;
+                    if (h_data) std::memset(h_data, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t));
+                    if (h_data_swa) std::memset(h_data_swa, 0, (n_pad - n_tokens)*n_kv_32*sizeof(uint32_t));
                 }
             }
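
Note on the mask layout (illustration, not part of the patch): the diff replaces the F16 KQ mask (one 0.0f/-INFINITY entry per KV position) with a packed I32 mask holding one bit per KV position, 32 positions per uint32_t. llama_set_inputs() sets a bit for every masked cell and stores the inverted word (~u), so a set bit means "visible"; the padded query rows are memset to 0, i.e. fully masked. The new ggml_vec_cpy_softcap_mask_f32 path then consumes the words 16 positions at a time as __mmask16 lane masks, keeping the value where the bit is set and writing -INFINITY where it is clear. The scalar sketch below models that packing and its expansion back to 0/-INFINITY scores; pack_mask_row() and expand_mask_row() are hypothetical helper names, not functions from the patch.

// Illustrative only: a scalar model of the packed KQ mask used by the patch.
// pack_mask_row() and expand_mask_row() are made-up helpers for this sketch.
#include <stdint.h>
#include <stdio.h>
#include <math.h>

// Pack one mask row: bit i of word i/32 is set when KV position i is visible.
// Mirrors llama_set_inputs(): bits are accumulated for *masked* cells, then the
// completed 32-bit word is stored inverted (~u).
static void pack_mask_row(const int n_kv, const int *visible, uint32_t *packed) {
    uint32_t u = 0, m = 1;
    int k = 0;
    for (int i = 0; i < n_kv; ++i) {
        if (!visible[i]) u |= m;                       // mark masked positions
        m <<= 1;
        if (!m) { packed[k++] = ~u; u = 0; m = 1; }    // flush a full word, inverted
    }
    if (m > 1) packed[k++] = ~u;                       // flush the partial last word
}

// Scalar equivalent of what the AVX512 kernel does with the mask:
// visible positions keep their score, masked positions become -INFINITY.
static void expand_mask_row(const int n_kv, const uint32_t *packed, float *scores) {
    for (int i = 0; i < n_kv; ++i) {
        const int bit = (packed[i/32] >> (i%32)) & 1;
        if (!bit) scores[i] = -INFINITY;
    }
}

int main(void) {
    enum { N_KV = 40 };
    int visible[N_KV];
    float scores[N_KV];
    uint32_t packed[(N_KV + 31)/32];

    // Causal-style row: positions 0..25 visible, the rest masked.
    for (int i = 0; i < N_KV; ++i) { visible[i] = (i <= 25); scores[i] = 1.0f; }
    pack_mask_row(N_KV, visible, packed);
    expand_mask_row(N_KV, packed, scores);

    for (int i = 0; i < N_KV; ++i) printf("%d:%s ", i, isinf(scores[i]) ? "-inf" : "1.0");
    printf("\n");
    return 0;
}

One practical consequence of this layout: a mask row shrinks from n_kv ggml_fp16_t values to (n_kv + 31)/32 uint32_t words, roughly a 16x smaller host-to-device upload per batch.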