WIP: Fusing K*Q and softmax - not working yet

This commit is contained in:
Iwan Kawrakow
2024-08-23 09:56:28 +03:00
parent bd99ed7d0a
commit b127c6cced
3 changed files with 430 additions and 4 deletions

View File

@@ -12761,6 +12761,8 @@ static void ggml_compute_forward_mul_mat(
#if GGML_USE_IQK_MULMAT
if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
//if (ith == 0) printf("%s(%-10s): %d multiplies with Nx = %d, Ny = %d ne00 = %d\n", __func__, dst->name,
// (int)(ne12*ne13), (int)ne01, (int)ne11, (int)ne00);
int counter = 0;
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
@@ -12773,6 +12775,10 @@ static void ggml_compute_forward_mul_mat(
}
}
}
#if IK_PRINT_TIMING
if (ith == 0) printf("%s(%s, 0): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
4.*ggml_nelements(dst)/(1024*1024));
#endif
return;
}
if (dst->type == GGML_TYPE_F32) {
@@ -12783,6 +12789,10 @@ static void ggml_compute_forward_mul_mat(
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
(float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
ith, nth)) goto IQK_MulMat_Not_Available1;
#if IK_PRINT_TIMING
if (ith == 0) printf("%s(%s, 1): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
4.*ggml_nelements(dst)/(1024*1024));
#endif
return;
}
IQK_MulMat_Not_Available1:;
@@ -14302,6 +14312,93 @@ static void ggml_compute_forward_diag_mask_zero(
}
}
static bool ggml_fused_mul_mat_softmax(const struct ggml_compute_params * params,
struct ggml_tensor * mul_mat,
struct ggml_tensor * soft_max) {
if (!(mul_mat->src[0]->type == GGML_TYPE_F16 || mul_mat->src[0]->type == GGML_TYPE_F32) ||
!(mul_mat->src[1]->type == GGML_TYPE_F16 || mul_mat->src[1]->type == GGML_TYPE_F32) ||
!(soft_max->type == GGML_TYPE_F16 ||soft_max->type == GGML_TYPE_F32) ||
!ggml_is_contiguous(soft_max) || !ggml_are_same_shape(mul_mat, soft_max)) {
return false;
}
const int ith = params->ith;
const int nth = params->nth;
//if (ith == 0) printf("%s: %s = softmax(%s = %s x %s)\n", __func__, soft_max->name, mul_mat->name, mul_mat->src[0]->name, mul_mat->src[1]->name);
const struct ggml_tensor * dst = mul_mat;
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const enum ggml_type type = src0->type;
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
float op_params[2];
memcpy(op_params, soft_max->op_params, sizeof(op_params));
const uint32_t n_head = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(op_params[0] ) / n_head_log2);
const float m1 = powf(2.0f, -(op_params[0] / 2.0f) / n_head_log2);
//if ((ne12*ne13)%nth == 0) {
// int counter = 0;
// for (int64_t i13 = 0; i13 < ne13; i13++) {
// for (int64_t i12 = 0; i12 < ne12; i12++) {
// if (counter++ % nth == ith) {
// const uint32_t h = i12;
// const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
// if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
// src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
// src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
// (float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
// params->wdata, params->wsize,
// soft_max->src[1]->data, op_params[0], slope, 0, 1)) return false;
// }
// }
// }
//} else {
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
const uint32_t h = i12;
const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
(float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
params->wdata, params->wsize,
soft_max->src[1]->data, op_params[0], slope, ith, nth)) {
if (ith == 0) printf("iqk_fused_mul_mat_softmax returned false!\n");
return false;
}
}
}
//}
//if (ith == 0) printf(" success!\n");
return true;
}
// ggml_compute_forward_soft_max
static void ggml_compute_forward_soft_max_f32(
@@ -14340,6 +14437,10 @@ static void ggml_compute_forward_soft_max_f32(
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
#if IK_PRINT_TIMING
if (ith == 0) printf("%s: %d x %d, %g MiB\n", __func__, nc, nr, 4.*nc*nr/(1024*1024));
#endif
// rows per thread
const int dr = (nr + nth - 1)/nth;
@@ -17357,11 +17458,12 @@ static void ggml_compute_forward_cross_entropy_loss_back(
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
GGML_ASSERT(params);
bool result = false;
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return;
return result;
}
#if IK_PRINT_TIMING
@@ -17459,7 +17561,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor);
if (next && next->op == GGML_OP_SOFT_MAX) {
result = ggml_fused_mul_mat_softmax(params, tensor, next);
} else {
ggml_compute_forward_mul_mat(params, tensor);
}
} break;
case GGML_OP_MUL_MAT_ID:
{
@@ -17701,6 +17807,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
int64_t t2 = ggml_time_us();
if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
#endif
return result;
}
////////////////////////////////////////////////////////////////////////////////
@@ -19526,7 +19633,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (ggml_is_noop(node)) continue;
ggml_compute_forward(&params, node);
bool skip_next = ggml_compute_forward(&params, node, node_n < cgraph->n_nodes - 1 ? cgraph->nodes[node_n + 1] : NULL);
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->ec = GGML_STATUS_ABORTED;
@@ -19537,6 +19644,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (state->shared->ec != GGML_STATUS_SUCCESS) {
break;
}
if (skip_next) ++node_n;
}
return 0;

View File

@@ -17,6 +17,7 @@
#include <cstring>
#include <type_traits>
#include <new> // for hardware_destructive_interference_size
#if defined IQK_IMPLEMENT
@@ -5915,6 +5916,305 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
#endif // __aarch64__
namespace {
#if defined(__ARM_NEON) && defined(__aarch64__)
// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// 4-lane expf. Range reduction: x = n*ln2 + b with n = round(x*log2(e))
// obtained via the 0x1.8p23 "shifter" trick; e^b comes from a short
// polynomial, and 2^n is reconstructed from the exponent bits of z.
inline float32x4_t v_expf(float32x4_t x) {
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    // z = x*log2(e) + shifter; the low mantissa bits of z hold round(x*log2(e))
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    // b = x - n*ln2, with ln2 split into hi/lo parts for extra precision
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
                                    vdupq_n_f32(0x1.7f7d1cp-20f));
    // shift the integer part of z into the float exponent field: k ~= 2^n
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    // c marks lanes with |n| > 126, i.e. where 2^n over/underflows the exponent
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    // polynomial approximation of e^b - 1 (Estrin-style evaluation)
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    // fast path: no lane needs special-case scaling
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    // slow path: apply the scale in two steps (s1*s1 or s2*s1) so overflowing
    // lanes saturate to +inf and underflowing lanes flush to 0
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}
#endif
#if defined(__AVX512F__) && defined(__AVX512DQ__)
// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// 16-lane expf. Same range reduction as the other v_expf variants
// (x = n*ln2 + b, n = round(x*log2(e)) via the 0x1.8p23 shifter), but the
// final 2^n scaling uses scalef, so only lanes with |n| > 192 need the
// explicit inf/0 fixup at the end.
inline __m512 v_expf(__m512 x) {
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
  const __m512 n = _mm512_sub_ps(z, r);
  // b = x - n*ln2, ln2 split into hi/lo parts for precision
  const __m512 b =
      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
  // lanes whose scale factor is out of range even for scalef's result
  const __mmask16 d =
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
  const __m512 u = _mm512_mul_ps(b, b);
  // polynomial approximation of e^b (Estrin-style evaluation)
  const __m512 j = _mm512_fmadd_ps(
      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                      _mm512_set1_ps(0x1.573e2ep-5f)),
                      u,
                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
      u,
      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
  // res = e^b * 2^n
  const __m512 res = _mm512_scalef_ps(j, n);
  if (_mm512_kortestz(d, d))
    return res;
  // extreme lanes: flush to 0 for large negative n, +inf for large positive n
  const __m512 zero = _mm512_setzero_ps();
  const __m512 alt = _mm512_mask_blend_ps(
      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
  return _mm512_mask_blend_ps(d, res, alt);
}
#endif
#if defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
//
// 8-lane expf. Mirrors the NEON variant: x = n*ln2 + b with
// n = round(x*log2(e)) from the 0x1.8p23 shifter trick, e^b from a short
// polynomial, 2^n rebuilt from the exponent bits of z, and a two-step
// rescale for lanes whose exponent over/underflows.
inline __m256 v_expf(__m256 x) {
  const __m256 r = _mm256_set1_ps(0x1.8p23f);
  // z = x*log2(e) + shifter; low bits of z hold round(x*log2(e))
  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
  const __m256 n = _mm256_sub_ps(z, r);
  // b = x - n*ln2 with a hi/lo split of ln2
  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
  // shift integer bits of z into the float exponent field: k ~= 2^n
  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
  const __m256 k = _mm256_castsi256_ps(
      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
  // c marks lanes with |n| > 126 (2^n out of normal exponent range)
  const __m256i c = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(126), _CMP_GT_OQ));
  const __m256 u = _mm256_mul_ps(b, b);
  // polynomial approximation of e^b - 1 (Estrin-style evaluation)
  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
  // fast path: no lane needs special-case scaling
  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
    return _mm256_fmadd_ps(j, k, k);
  // slow path: two-step rescale so overflow lanes go to +inf, underflow to 0
  const __m256i g = _mm256_and_si256(
      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
      _mm256_set1_epi32(0x82000000u));
  const __m256 s1 =
      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
  // d marks lanes with |n| > 192 (result is exactly inf or 0)
  const __m256i d = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(192), _CMP_GT_OQ));
  return _mm256_or_ps(
      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
      _mm256_andnot_ps(
          _mm256_castsi256_ps(d),
          _mm256_or_ps(
              _mm256_and_ps(_mm256_castsi256_ps(c),
                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
#endif
// Applies y = scale*x (+ slope*mask) in place over one row of nc floats and
// returns the row maximum, as needed for a numerically stable softmax.
// mp == NULL means no mask; use_fp16 selects f16 vs f32 mask storage.
// NOTE(review): uses AVX512 intrinsics unconditionally, unlike the v_expf
// helpers above which are guarded by feature macros — confirm this file is
// only built for AVX512-capable targets.
inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) {
    __m512 vscale = _mm512_set1_ps(scale);
    __m512 vmax = _mm512_set1_ps(-INFINITY);
    float scalar_max = -INFINITY;   // max over the scalar tail (nc not a multiple of the SIMD width)
    if (mp) {
        __m512 vslope = _mm512_set1_ps(slope);
        if (use_fp16) {
            const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp;
            // second max accumulator so the two 16-lane halves of each
            // 32-element chunk don't serialize on a single register
            __m512 vmax1 = _mm512_set1_ps(-INFINITY);
            for (int i = 0; i < nc/32; ++i) {
                const __m512 m1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+0));
                const __m512 m2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+1));
                const __m512 x1 = _mm512_loadu_ps(sp + 32*i +  0);
                const __m512 x2 = _mm512_loadu_ps(sp + 32*i + 16);
                const __m512 y1 = _mm512_fmadd_ps(vslope, m1, _mm512_mul_ps(vscale, x1));
                const __m512 y2 = _mm512_fmadd_ps(vslope, m2, _mm512_mul_ps(vscale, x2));
                vmax  = _mm512_max_ps(vmax , y1);
                vmax1 = _mm512_max_ps(vmax1, y2);
                _mm512_storeu_ps(sp + 32*i +  0, y1);
                _mm512_storeu_ps(sp + 32*i + 16, y2);
            }
            vmax = _mm512_max_ps(vmax, vmax1);
            // scalar tail: remaining nc % 32 elements
            for (int i = 32*(nc/32); i < nc; ++i) {
                sp[i] = scale*sp[i] + slope*GGML_FP16_TO_FP32(mp_f16[i]);
                scalar_max = std::max(scalar_max, sp[i]);
            }
        } else {
            const float * mp_f32 = (const float *)mp;
            for (int i = 0; i < nc/16; ++i) {
                const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i);
                const __m512 x = _mm512_loadu_ps(sp + 16*i);
                const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x));
                vmax = _mm512_max_ps(vmax, y);
                _mm512_storeu_ps(sp + 16*i, y);
            }
            // scalar tail: remaining nc % 16 elements
            for (int i = 16*(nc/16); i < nc; ++i) {
                sp[i] = scale*sp[i] + slope*mp_f32[i];
                scalar_max = std::max(scalar_max, sp[i]);
            }
        }
    } else {
        // no mask: just scale in place while tracking the maximum
        for (int i = 0; i < nc/16; ++i) {
            const __m512 x = _mm512_loadu_ps(sp + 16*i);
            const __m512 y = _mm512_mul_ps(vscale, x);
            vmax = _mm512_max_ps(vmax, y);
            _mm512_storeu_ps(sp + 16*i, y);
        }
        for (int i = 16*(nc/16); i < nc; ++i) {
            sp[i] = scale*sp[i];
            scalar_max = std::max(scalar_max, sp[i]);
        }
    }
    float vector_max = _mm512_reduce_max_ps(vmax);
    return std::max(scalar_max, vector_max);
}
// Replaces each of the nc values in sp with exp(sp[i] - max) and returns the
// sum of the exponentials (the softmax denominator).
inline float do_soft_max(int nc, float * sp, float max) {
    const __m512 neg_max = _mm512_set1_ps(-max);
    __m512 acc = _mm512_setzero_ps();
    const int nblock = nc/16;
    for (int ib = 0; ib < nblock; ++ib) {
        const __m512 e = v_expf(_mm512_add_ps(_mm512_loadu_ps(sp + 16*ib), neg_max));
        _mm512_storeu_ps(sp + 16*ib, e);
        acc = _mm512_add_ps(acc, e);
    }
    float total = _mm512_reduce_add_ps(acc);
    // scalar tail for nc not a multiple of 16
    for (int i = 16*nblock; i < nc; ++i) {
        const float e = expf(sp[i] - max);
        sp[i] = e;
        total += e;
    }
    return total;
}
inline void do_scale(int nc, const float * sp, float * dp, float scale) {
auto vscale = _mm512_set1_ps(scale);
for (int i = 0; i < nc/16; ++i) {
auto x = _mm512_loadu_ps(sp + 16*i);
auto y = _mm512_mul_ps(x, vscale);
_mm512_storeu_ps(dp + 16*i, y);
}
for (int i = 16*(nc/16); i < nc; ++i) {
dp[i] = sp[i]*scale;
}
}
// Full softmax pipeline for one row of nc values: apply scale (and, when mp is
// non-NULL, slope*mask), exponentiate stably, then normalize into dp.
void softmax_extended(int nc, float * sp, float * dp, float scale, float slope, const char * mp, bool mask_is_fp16) {
    const float row_max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16);
    const float row_sum = do_soft_max(nc, sp, row_max);
    do_scale(nc, sp, dp, 1.f/row_sum);
}
}
// Fused matrix-multiplication + softmax kernel. For each of Ny columns of the
// result (rows of B), the Nx dot products against the rows of A are computed
// into a per-thread scratch slice of work_buffer and immediately run through
// softmax_extended into C (row stride stride_C), with scale and slope*mask
// applied before the softmax. Columns are split across nth threads in blocks
// of k_y_step. Returns false if the type combination is unsupported or
// work_buffer is too small, so the caller can fall back to the unfused path.
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
        int int_typeA, const void * A, long strideA,
        int int_typeB, const void * B, long strideB,
        float * C, long stride_C,
        char * work_buffer, long work_size,
        const char * mask, float scale, float slope,
        int ith, int nth) {

    constexpr int k_y_step = 5; // TODO: make this CPU dependent
#if defined(__cpp_lib_hardware_interference_size)
    constexpr int k_cache_line = std::hardware_destructive_interference_size;
#else
    constexpr int k_cache_line = 64;
#endif

    auto typeA = ggml_type(int_typeA);
    auto typeB = ggml_type(int_typeB);
    if (!(typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) || !(typeB == GGML_TYPE_F16 || typeB == GGML_TYPE_F32)) {
        return false;
    }

    int ny_per_thread = (Ny + nth - 1)/nth;
    int first_y = ny_per_thread*ith;
    if (first_y >= Ny) {
        return true; // nothing left for this thread
    }
    // Bug fix: the last active thread generally owns fewer than ny_per_thread
    // columns. The loop below used to run over ny_per_thread unconditionally,
    // reading past the end of B/mask and writing past the end of C.
    int my_ny = std::min(ny_per_thread, int(Ny - first_y));

    // Scratch sizing must be identical on every thread, because each thread
    // derives its own buffer offset from ith; hence ny_per_thread, not my_ny.
    int n_cols_per_iter = std::min(ny_per_thread, k_y_step);
    int needed_work_size_per_thread = k_cache_line*((n_cols_per_iter*Nx*sizeof(float) + k_cache_line - 1)/k_cache_line);
    int needed_work_size = needed_work_size_per_thread*nth;
    if (needed_work_size > work_size) {
        return false;
    }

    // Kernel dispatch table: funcs[k-1] multiplies k columns at a time.
    std::array<mul_mat_t, k_y_step> funcs;
    if (typeA == GGML_TYPE_F16) {
        if (typeB == GGML_TYPE_F16) {
            funcs[0] = mul_mat_fX_fY_T<1, ggml_half, ggml_half>;
            funcs[1] = mul_mat_fX_fY_T<2, ggml_half, ggml_half>;
            funcs[2] = mul_mat_fX_fY_T<3, ggml_half, ggml_half>;
            funcs[3] = mul_mat_fX_fY_T<4, ggml_half, ggml_half>;
            funcs[4] = mul_mat_fX_fY_T<5, ggml_half, ggml_half>;
        } else {
            funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
            funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
            funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>;
            funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>;
            funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>;
        }
    } else {
        if (typeB == GGML_TYPE_F16) {
            funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
            funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
            funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>;
            funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>;
            funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>;
        } else {
            funcs[0] = mul_mat_fX_fY_T<1, float, float>;
            funcs[1] = mul_mat_fX_fY_T<2, float, float>;
            funcs[2] = mul_mat_fX_fY_T<3, float, float>;
            funcs[3] = mul_mat_fX_fY_T<4, float, float>;
            funcs[4] = mul_mat_fX_fY_T<5, float, float>;
        }
    }

    auto row_size_qx = strideA*ggml_type_size(typeA);
    auto row_size_qy = strideB*ggml_type_size(typeB);

    // Mul-mat results land in this thread's work-buffer slice (info.s); the
    // kernels walk through B starting at this thread's first column (info.cy).
    DataInfo info{(float *)(work_buffer + needed_work_size_per_thread*ith), (const char *)B + first_y*row_size_qy,
        Nx*sizeof(float), row_size_qy, 0, 1, nullptr, 0};

    C += first_y*stride_C;
    // NOTE(review): assumes an f16 mask with row stride Nx when mask != NULL
    // (matches the hard-coded 'true' passed to softmax_extended) — confirm.
    const char * mp = mask ? mask + first_y*Nx*sizeof(ggml_half) : nullptr;

    int n_step = (my_ny + k_y_step - 1)/k_y_step;
    for (int i_step = 1; i_step <= n_step; ++i_step) {
        int this_ny = i_step*k_y_step <= my_ny ? k_y_step : my_ny - (i_step - 1)*k_y_step;
        funcs[this_ny-1](ne00, A, row_size_qx, info, Nx);
        // softmax each freshly computed column and store it into C
        for (int iy = 0; iy < this_ny; ++iy) {
            softmax_extended(Nx, info.s + iy*Nx, C, scale, slope, mp, true);
            C += stride_C;
            if (mp) mp += Nx*sizeof(ggml_half);
        }
        info.cy += row_size_qy*this_ny;
    }

    return true;
}
#else // IQK_IMPLEMENT
bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
@@ -5926,4 +6226,14 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
return false;
}
// Stub for builds without IQK_IMPLEMENT: the fused path is unavailable, so the
// caller always falls back to separate mul_mat + soft_max.
// Bug fix: the parameter list contained a spurious uint32_t (19 parameters)
// that matched neither the header declaration nor the real implementation
// (18 parameters), breaking the non-IQK build.
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
                               int typeA, const void * A, long strideA,
                               int typeB, const void * B, long strideB,
                               float * C, long stride_C,
                               char * work_buffer, long work_size,
                               const char * mask, float scale, float slope,
                               int ith, int nth) {
    (void)Nx; (void)Ny; (void)ne00; (void)typeA; (void)A; (void)strideA;
    (void)typeB; (void)B; (void)strideB; (void)C; (void)stride_C;
    (void)work_buffer; (void)work_size; (void)mask; (void)scale; (void)slope;
    (void)ith; (void)nth;
    return false;
}
#endif

View File

@@ -21,6 +21,13 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
int typeB, const void * B, long strideB,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
// Fused mul_mat + softmax: for each of Ny result columns, computes the Nx dot
// products of A against one row of B into work_buffer, then writes
// softmax(scale*dot + slope*mask) into C (row stride stride_C). Columns are
// split across nth threads (this call handles thread ith). Returns false when
// the type combination is unsupported or work_buffer (work_size bytes) is too
// small, in which case the caller must run the two ops separately.
// NOTE(review): the mask, when non-NULL, is assumed to be f16 with row stride
// Nx by the current implementation — confirm with callers.
bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
                               int typeA, const void * A, long strideA,
                               int typeB, const void * B, long strideB,
                               float * C, long stride_C,
                               char * work_buffer, long work_size,
                               const char * mask, float scale, float slope,
                               int ith, int nth);
#ifdef __cplusplus
}