mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
WIP: Fusing K*Q and softmax - not working yet
This commit is contained in:
117
ggml/src/ggml.c
117
ggml/src/ggml.c
@@ -12761,6 +12761,8 @@ static void ggml_compute_forward_mul_mat(
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
|
||||
//if (ith == 0) printf("%s(%-10s): %d multiplies with Nx = %d, Ny = %d ne00 = %d\n", __func__, dst->name,
|
||||
// (int)(ne12*ne13), (int)ne01, (int)ne11, (int)ne00);
|
||||
int counter = 0;
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
@@ -12773,6 +12775,10 @@ static void ggml_compute_forward_mul_mat(
|
||||
}
|
||||
}
|
||||
}
|
||||
#if IK_PRINT_TIMING
|
||||
if (ith == 0) printf("%s(%s, 0): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
|
||||
4.*ggml_nelements(dst)/(1024*1024));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
@@ -12783,6 +12789,10 @@ static void ggml_compute_forward_mul_mat(
|
||||
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
|
||||
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
|
||||
ith, nth)) goto IQK_MulMat_Not_Available1;
|
||||
#if IK_PRINT_TIMING
|
||||
if (ith == 0) printf("%s(%s, 1): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
|
||||
4.*ggml_nelements(dst)/(1024*1024));
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
IQK_MulMat_Not_Available1:;
|
||||
@@ -14302,6 +14312,93 @@ static void ggml_compute_forward_diag_mask_zero(
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_fused_mul_mat_softmax(const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * mul_mat,
|
||||
struct ggml_tensor * soft_max) {
|
||||
|
||||
if (!(mul_mat->src[0]->type == GGML_TYPE_F16 || mul_mat->src[0]->type == GGML_TYPE_F32) ||
|
||||
!(mul_mat->src[1]->type == GGML_TYPE_F16 || mul_mat->src[1]->type == GGML_TYPE_F32) ||
|
||||
!(soft_max->type == GGML_TYPE_F16 ||soft_max->type == GGML_TYPE_F32) ||
|
||||
!ggml_is_contiguous(soft_max) || !ggml_are_same_shape(mul_mat, soft_max)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
//if (ith == 0) printf("%s: %s = softmax(%s = %s x %s)\n", __func__, soft_max->name, mul_mat->name, mul_mat->src[0]->name, mul_mat->src[1]->name);
|
||||
|
||||
const struct ggml_tensor * dst = mul_mat;
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
|
||||
float op_params[2];
|
||||
memcpy(op_params, soft_max->op_params, sizeof(op_params));
|
||||
|
||||
const uint32_t n_head = ne02;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(op_params[0] ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(op_params[0] / 2.0f) / n_head_log2);
|
||||
|
||||
//if ((ne12*ne13)%nth == 0) {
|
||||
// int counter = 0;
|
||||
// for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
// for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
// if (counter++ % nth == ith) {
|
||||
// const uint32_t h = i12;
|
||||
// const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
// if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
|
||||
// src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
|
||||
// src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
|
||||
// (float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
|
||||
// params->wdata, params->wsize,
|
||||
// soft_max->src[1]->data, op_params[0], slope, 0, 1)) return false;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//} else {
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
const uint32_t h = i12;
|
||||
const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
|
||||
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
|
||||
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
|
||||
(float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
|
||||
params->wdata, params->wsize,
|
||||
soft_max->src[1]->data, op_params[0], slope, ith, nth)) {
|
||||
if (ith == 0) printf("iqk_fused_mul_mat_softmax returned false!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
//}
|
||||
|
||||
//if (ith == 0) printf(" success!\n");
|
||||
return true;
|
||||
}
|
||||
|
||||
// ggml_compute_forward_soft_max
|
||||
|
||||
static void ggml_compute_forward_soft_max_f32(
|
||||
@@ -14340,6 +14437,10 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
const int nc = src0->ne[0];
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
if (ith == 0) printf("%s: %d x %d, %g MiB\n", __func__, nc, nr, 4.*nc*nr/(1024*1024));
|
||||
#endif
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
@@ -17357,11 +17458,12 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
||||
|
||||
/////////////////////////////////
|
||||
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
|
||||
GGML_ASSERT(params);
|
||||
|
||||
bool result = false;
|
||||
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
|
||||
return;
|
||||
return result;
|
||||
}
|
||||
|
||||
#if IK_PRINT_TIMING
|
||||
@@ -17459,7 +17561,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
if (next && next->op == GGML_OP_SOFT_MAX) {
|
||||
result = ggml_fused_mul_mat_softmax(params, tensor, next);
|
||||
} else {
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
@@ -17701,6 +17807,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
int64_t t2 = ggml_time_us();
|
||||
if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -19526,7 +19633,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
|
||||
if (ggml_is_noop(node)) continue;
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
bool skip_next = ggml_compute_forward(¶ms, node, node_n < cgraph->n_nodes - 1 ? cgraph->nodes[node_n + 1] : NULL);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
state->shared->ec = GGML_STATUS_ABORTED;
|
||||
@@ -19537,6 +19644,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
if (state->shared->ec != GGML_STATUS_SUCCESS) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (skip_next) ++node_n;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <new> // for hardware_destructive_interference_size
|
||||
|
||||
#if defined IQK_IMPLEMENT
|
||||
|
||||
@@ -5915,6 +5916,305 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
|
||||
#endif // __aarch64__
|
||||
|
||||
namespace {
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
// copy-pasted from Justine Tunney's contribution to llama.cpp
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
// Vectorized e^x for each lane of a float32x4_t (see accuracy notes above).
inline float32x4_t v_expf(float32x4_t x) {
    // Round-to-nearest trick: adding 0x1.8p23f pushes the integer part of
    // x*log2(e) into the low mantissa bits of z; n recovers it as a float.
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); // x*log2(e) + r
    const float32x4_t n = vsubq_f32(z, r);
    // b = x - n*ln(2); ln(2) is split into hi/lo parts for extra precision.
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    // Build 2^n directly in the exponent bits; k = 2^n as a float.
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    // c flags lanes with |n| > 126 where 2^n alone would over/underflow.
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    // Polynomial approximation of expm1(b) on the reduced interval.
    const float32x4_t u = vmulq_f32(b, b);
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    // Fast path: no lane near the exponent limits — result is k + k*j.
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    // Slow path: rescale in two multiplications (s1, s2) so intermediate
    // values stay representable; lanes with |n| > 192 saturate via s1*s1.
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
|
||||
#endif
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
|
||||
// copy-pasted from Justine Tunney's contribution to llama.cpp
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
// Vectorized e^x for each lane of a __m512 (see accuracy notes above).
inline __m512 v_expf(__m512 x) {
    // Round-to-nearest trick: n = round(x * log2(e)) via the 0x1.8p23f bias.
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);
    // b = x - n*ln(2); ln(2) is split into hi/lo parts for extra precision.
    const __m512 b =
        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                         _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    // d marks lanes with |n| > 192 that must saturate to 0 or +inf.
    const __mmask16 d =
        _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    // Polynomial approximation of e^b on the reduced interval.
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u,
        _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    // scalef computes j * 2^n in one instruction, including gradual underflow.
    const __m512 res = _mm512_scalef_ps(j, n);
    if (_mm512_kortestz(d, d))
        return res; // fast path: no lane saturates
    // Saturated lanes: 0 for n <= 0 (underflow), +inf for n > 0 (overflow).
    const __m512 zero = _mm512_setzero_ps();
    const __m512 alt = _mm512_mask_blend_ps(
        _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(d, res, alt);
}
|
||||
#endif
|
||||
|
||||
#if defined(__AVX2__) && defined(__FMA__)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
// Vectorized e^x for each lane of a __m256 (see accuracy notes above).
inline __m256 v_expf(__m256 x) {
    // Round-to-nearest trick: n = round(x * log2(e)) via the 0x1.8p23f bias.
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    const __m256 n = _mm256_sub_ps(z, r);
    // b = x - n*ln(2); ln(2) is split into hi/lo parts for extra precision.
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    // Build 2^n directly in the exponent bits; k = 2^n as a float.
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    const __m256 k = _mm256_castsi256_ps(
        _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    // c flags lanes with |n| > 126 where 2^n alone would over/underflow
    // (andnot with -0.f clears the sign bit, i.e. computes |n|).
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(126), _CMP_GT_OQ));
    // Polynomial approximation of expm1(b) on the reduced interval.
    const __m256 u = _mm256_mul_ps(b, b);
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    // Fast path: no lane near the exponent limits — result is k + k*j.
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    // Slow path: rescale in two multiplications (s1, s2) so intermediate
    // values stay representable; lanes with |n| > 192 saturate via s1*s1.
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 =
        _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                      _mm256_set1_ps(192), _CMP_GT_OQ));
    // Blend the three cases (saturated / rescaled / fast) with mask logic.
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c),
                              _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
|
||||
#endif
|
||||
|
||||
// First softmax pass: updates the nc logits in sp in place to
// sp[i] = scale*sp[i] + slope*mask[i] (mask optional; fp16 or fp32 layout
// selected by use_fp16) and returns the maximum of the updated values.
// NOTE(review): this helper uses AVX-512 intrinsics unconditionally, unlike
// the v_expf overloads above which are guarded by __AVX512F__ — confirm the
// intended build requirements for this translation unit.
inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) {
    __m512 vscale = _mm512_set1_ps(scale);
    __m512 vmax = _mm512_set1_ps(-INFINITY);
    float scalar_max = -INFINITY; // max over the scalar tail elements
    if (mp) {
        __m512 vslope = _mm512_set1_ps(slope);
        if (use_fp16) {
            const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp;
            // Two accumulators (vmax/vmax1) to break the dependency chain
            // while processing 32 elements per iteration.
            __m512 vmax1 = _mm512_set1_ps(-INFINITY);
            for (int i = 0; i < nc/32; ++i) {
                const __m512 m1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+0));
                const __m512 m2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+1));
                const __m512 x1 = _mm512_loadu_ps(sp + 32*i +  0);
                const __m512 x2 = _mm512_loadu_ps(sp + 32*i + 16);
                const __m512 y1 = _mm512_fmadd_ps(vslope, m1, _mm512_mul_ps(vscale, x1));
                const __m512 y2 = _mm512_fmadd_ps(vslope, m2, _mm512_mul_ps(vscale, x2));
                vmax  = _mm512_max_ps(vmax , y1);
                vmax1 = _mm512_max_ps(vmax1, y2);
                _mm512_storeu_ps(sp + 32*i +  0, y1);
                _mm512_storeu_ps(sp + 32*i + 16, y2);
            }
            vmax = _mm512_max_ps(vmax, vmax1);
            // Scalar tail for nc not a multiple of 32.
            for (int i = 32*(nc/32); i < nc; ++i) {
                sp[i] = scale*sp[i] + slope*GGML_FP16_TO_FP32(mp_f16[i]);
                scalar_max = std::max(scalar_max, sp[i]);
            }
        } else {
            const float * mp_f32 = (const float *)mp;
            for (int i = 0; i < nc/16; ++i) {
                const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i);
                const __m512 x = _mm512_loadu_ps(sp + 16*i);
                const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
                vmax = _mm512_max_ps(vmax, y);
                _mm512_storeu_ps(sp + 16*i, y);
            }
            // Scalar tail for nc not a multiple of 16.
            for (int i = 16*(nc/16); i < nc; ++i) {
                sp[i] = scale*sp[i] + slope*mp_f32[i];
                scalar_max = std::max(scalar_max, sp[i]);
            }
        }
    } else {
        // No mask: just scale the logits.
        for (int i = 0; i < nc/16; ++i) {
            const __m512 x = _mm512_loadu_ps(sp + 16*i);
            const __m512 y = _mm512_mul_ps(vscale, x);
            vmax = _mm512_max_ps(vmax, y);
            _mm512_storeu_ps(sp + 16*i, y);
        }
        for (int i = 16*(nc/16); i < nc; ++i) {
            sp[i] = scale*sp[i];
            scalar_max = std::max(scalar_max, sp[i]);
        }
    }
    // Combine vector-lane and scalar-tail maxima.
    float vector_max = _mm512_reduce_max_ps(vmax);
    return std::max(scalar_max, vector_max);
}
|
||||
|
||||
// Second softmax pass: replaces each sp[i] with exp(sp[i] - max) in place and
// returns the sum of the exponentials (the normalization constant).
inline float do_soft_max(int nc, float * sp, float max) {
    __m512 vmax = _mm512_set1_ps(-max); // subtract the row max for numerical stability
    __m512 vsum = _mm512_setzero_ps();
    for (int i = 0; i < nc/16; ++i) {   // 16 floats per AVX-512 register
        auto x = _mm512_loadu_ps(sp + 16*i);
        auto y = v_expf(_mm512_add_ps(x, vmax));
        vsum = _mm512_add_ps(vsum, y);
        _mm512_storeu_ps(sp + 16*i, y);
    }
    float sum = _mm512_reduce_add_ps(vsum);
    // Scalar tail for nc not a multiple of 16.
    for (int i = 16*(nc/16); i < nc; ++i) {
        float y = expf(sp[i] - max);
        sum += y;
        sp[i] = y;
    }
    return sum;
}
|
||||
|
||||
// Final softmax pass: dp[i] = sp[i] * scale for i in [0, nc). Used to divide
// the exponentials by their sum (scale = 1/sum), writing to the destination.
inline void do_scale(int nc, const float * sp, float * dp, float scale) {
    auto vscale = _mm512_set1_ps(scale);
    for (int i = 0; i < nc/16; ++i) {   // 16 floats per AVX-512 register
        auto x = _mm512_loadu_ps(sp + 16*i);
        auto y = _mm512_mul_ps(x, vscale);
        _mm512_storeu_ps(dp + 16*i, y);
    }
    // Scalar tail for nc not a multiple of 16.
    for (int i = 16*(nc/16); i < nc; ++i) {
        dp[i] = sp[i]*scale;
    }
}
|
||||
|
||||
// Full softmax over one row of nc logits: (1) apply scale and optional
// slope*mask in place into sp, (2) exponentiate relative to the row max,
// (3) write the normalized probabilities to dp. sp is clobbered.
void softmax_extended(int nc, float * sp, float * dp, float scale, float slope, const char * mp, bool mask_is_fp16) {
    auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
    auto sum = do_soft_max(nc, sp, max);
    do_scale(nc, sp, dp, 1.f/sum);
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
|
||||
int int_typeA, const void * A, long strideA,
|
||||
int int_typeB, const void * B, long strideB,
|
||||
float * C, long stride_C,
|
||||
char * work_buffer, long work_size,
|
||||
const char * mask, float scale, float slope,
|
||||
int ith, int nth) {
|
||||
|
||||
constexpr int k_y_step = 5; // TODO: make this CPU dependent
|
||||
|
||||
#if defined(__cpp_lib_hardware_interference_size)
|
||||
constexpr int k_cache_line = std::hardware_destructive_interference_size;
|
||||
#else
|
||||
constexpr int k_cache_line = 64;
|
||||
#endif
|
||||
|
||||
auto typeA = ggml_type(int_typeA);
|
||||
auto typeB = ggml_type(int_typeB);
|
||||
|
||||
if (!(typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) || !(typeB == GGML_TYPE_F16 || typeB == GGML_TYPE_F32)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int ny_per_thread = (Ny + nth - 1)/nth;
|
||||
int first_y = ny_per_thread*ith;
|
||||
if (first_y >= Ny) {
|
||||
return true;
|
||||
}
|
||||
int n_cols_per_iter = std::min(ny_per_thread, k_y_step);
|
||||
int needed_work_size_per_thread = k_cache_line*((n_cols_per_iter*Nx*sizeof(float) + k_cache_line - 1)/k_cache_line);
|
||||
int needed_work_size = needed_work_size_per_thread*nth;
|
||||
if (needed_work_size > work_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<mul_mat_t, k_y_step> funcs;
|
||||
if (typeA == GGML_TYPE_F16) {
|
||||
if (typeB == GGML_TYPE_F16) {
|
||||
funcs[0] = mul_mat_fX_fY_T<1, ggml_half, ggml_half>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2, ggml_half, ggml_half>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3, ggml_half, ggml_half>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4, ggml_half, ggml_half>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5, ggml_half, ggml_half>;
|
||||
} else {
|
||||
funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>;
|
||||
}
|
||||
} else {
|
||||
if (typeB == GGML_TYPE_F16) {
|
||||
funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>;
|
||||
} else {
|
||||
funcs[0] = mul_mat_fX_fY_T<1, float, float>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2, float, float>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3, float, float>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4, float, float>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5, float, float>;
|
||||
}
|
||||
}
|
||||
|
||||
auto row_size_qx = strideA*ggml_type_size(typeA);
|
||||
auto row_size_qy = strideB*ggml_type_size(typeB);
|
||||
|
||||
DataInfo info{(float *)(work_buffer + needed_work_size_per_thread*ith), (const char *)B + first_y*row_size_qy,
|
||||
Nx*sizeof(float), row_size_qy, 0, 1, nullptr, 0};
|
||||
|
||||
C += first_y*stride_C;
|
||||
|
||||
const char * mp = mask ? mask + first_y*Nx*sizeof(ggml_half) : nullptr;
|
||||
|
||||
int n_step = (ny_per_thread + k_y_step - 1)/k_y_step;
|
||||
for (int i_step = 1; i_step <= n_step; ++i_step) {
|
||||
int this_ny = i_step*k_y_step <= ny_per_thread ? k_y_step : ny_per_thread - (i_step - 1)*k_y_step;
|
||||
funcs[this_ny-1](ne00, A, row_size_qx, info, Nx);
|
||||
// Now we need to compute the softmax and store the result in C
|
||||
for (int iy = 0; iy < this_ny; ++iy) {
|
||||
softmax_extended(Nx, info.s + iy*Nx, C, scale, slope, mp, true);
|
||||
C += stride_C;
|
||||
if (mp) mp += Nx*sizeof(ggml_half);
|
||||
}
|
||||
info.cy += row_size_qy*this_ny;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#else // IQK_IMPLEMENT
|
||||
|
||||
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
|
||||
@@ -5926,4 +6226,14 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
|
||||
return false;
|
||||
}
|
||||
|
||||
// Stub for builds without IQK_IMPLEMENT: the fused mul_mat+softmax path is
// unavailable, so always report failure and let the caller use the generic
// implementation. The parameter list must match the header declaration
// (mask, scale, slope, ith, nth) exactly — the previous version had a stray
// uint32_t parameter, which in C++ defines a different overload and leaves
// the declared symbol undefined at link time.
bool iqk_fused_mul_mat_softmax(long, long, long,
                               int, const void *, long,
                               int, const void *, long,
                               float *, long,
                               char *, long,
                               const char *, float, float,
                               int, int) {
    return false;
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -21,6 +21,13 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
int typeB, const void * B, long strideB,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
||||
|
||||
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
|
||||
int typeA, const void * A, long strideA,
|
||||
int typeB, const void * B, long strideB,
|
||||
float * C, long stride_C,
|
||||
char * work_buffer, long work_size,
|
||||
const char * mask, float scale, float slope,
|
||||
int ith, int nth);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user