Softcap: WIP

Fuses scale + tanh + scale, as used for soft-capping in some
models.

CPU only for now. ~1.4% faster PP-512 on Gemma2-9b; no effect on TG.

Somewhat surprisingly the improvement does not increase as I
go to longer contexts. Gemma2 does softcap on K*Q, which grows
quadratically with context length, so I would have thought
the benefit from fusing scale, tanh, scale would increase.
But no, no luck.
This commit is contained in:
Iwan Kawrakow
2024-08-01 20:32:28 +03:00
parent a73702d93b
commit c4951cbc35
3 changed files with 213 additions and 10 deletions

View File

@@ -514,6 +514,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
@@ -1223,6 +1224,19 @@ extern "C" {
struct ggml_tensor * a,
float s);
// fused soft-cap: element-wise s_after * tanh(s_before * a)
// replaces the scale -> tanh -> scale sequence with a single op (GGML_OP_SOFTCAP)
// the CPU forward pass handles F32 only; the backward pass is not implemented
GGML_API struct ggml_tensor * ggml_softcap(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s_before,
float s_after);
// in-place variant of the fused soft-cap s_after * tanh(s_before * a), returns view(a)
GGML_API struct ggml_tensor * ggml_softcap_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s_before,
float s_after);
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
struct ggml_context * ctx,

View File

@@ -2558,6 +2558,14 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
// Fused soft-cap for 4 floats (NEON): returns s_after * tanh(s_before * x),
// with tanh(y) evaluated as (exp(2y) - 1)/(exp(2y) + 1).
inline static float32x4_t ggml_v_softcap(float32x4_t x, float s_before, float s_after) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    // Cap 2y before exponentiating: once exp(2y) overflows to +inf the quotient
    // below becomes inf/inf = NaN instead of 1. At 2y = 40 the quotient already
    // rounds to exactly 1.0f, so the cap leaves every finite result unchanged.
    // vminq_f32 yields NaN when either operand is NaN, so NaN inputs still propagate.
    const float32x4_t two_y = vminq_f32(vmulq_f32(x, vdupq_n_f32(2.f*s_before)), vdupq_n_f32(40.0f));
    const float32x4_t exp_two_y = ggml_v_expf(two_y);
    const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_y, one), vaddq_f32(exp_two_y, one));
    return vmulq_f32(th, vdupq_n_f32(s_after));
}
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
// adapted from arm limited optimized routine
@@ -2607,6 +2615,13 @@ inline static __m512 ggml_v_tanh(__m512 x) {
return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
}
// Fused soft-cap for 16 floats (AVX512): returns s_after * tanh(y) with
// tanh(y) evaluated as (exp(2y) - 1)/(exp(2y) + 1).
// NOTE: unlike the float-taking variants, x is multiplied by s_before as-is, so
// s_before must already hold the doubled pre-scale (2*s_before) broadcast to all
// lanes, which is what ggml_vec_softcap_f32 passes. s_after is the broadcast post-scale.
inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
    const __m512 one = _mm512_set1_ps(1.0f);
    // Cap 2y before exponentiating: once exp(2y) overflows to +inf the quotient
    // below becomes inf/inf = NaN instead of 1. At 2y = 40 the quotient already
    // rounds to exactly 1.0f, so the cap leaves every finite result unchanged.
    // The cap is the FIRST operand: min returns its second operand when either
    // one is NaN, so NaN inputs still propagate.
    const __m512 two_y = _mm512_min_ps(_mm512_set1_ps(40.0f), _mm512_mul_ps(x, s_before));
    const __m512 exp_two_y = ggml_v_expf(two_y);
    const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_y, one), _mm512_add_ps(exp_two_y, one));
    return _mm512_mul_ps(th, s_after);
}
#elif defined(__AVX2__) && defined(__FMA__)
// adapted from arm limited optimized routine
@@ -2668,6 +2683,13 @@ inline static __m256 ggml_v_tanh(__m256 x) {
return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
}
// Fused soft-cap for 8 floats (AVX2): returns s_after * tanh(s_before * x),
// with tanh(y) evaluated as (exp(2y) - 1)/(exp(2y) + 1).
inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
    const __m256 one = _mm256_set1_ps(1.0f);
    // Cap 2y before exponentiating: once exp(2y) overflows to +inf the quotient
    // below becomes inf/inf = NaN instead of 1. At 2y = 40 the quotient already
    // rounds to exactly 1.0f, so the cap leaves every finite result unchanged.
    // The cap is the FIRST operand: min returns its second operand when either
    // one is NaN, so NaN inputs still propagate.
    const __m256 two_y = _mm256_min_ps(_mm256_set1_ps(40.0f), _mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
    const __m256 exp_two_y = ggml_v_expf(two_y);
    const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_y, one), _mm256_add_ps(exp_two_y, one));
    return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
}
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
#if defined(__FMA__)
@@ -2728,6 +2750,13 @@ inline static __m128 ggml_v_tanh(__m128 x) {
return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
}
// Fused soft-cap for 4 floats (SSE2): returns s_after * tanh(s_before * x),
// with tanh(y) evaluated as (exp(2y) - 1)/(exp(2y) + 1).
inline static __m128 ggml_v_softcap(__m128 x, float s_before, float s_after) {
    const __m128 one = _mm_set1_ps(1.0f);
    // Cap 2y before exponentiating: once exp(2y) overflows to +inf the quotient
    // below becomes inf/inf = NaN instead of 1. At 2y = 40 the quotient already
    // rounds to exactly 1.0f, so the cap leaves every finite result unchanged.
    // The cap is the FIRST operand: min returns its second operand when either
    // one is NaN, so NaN inputs still propagate.
    const __m128 two_y = _mm_min_ps(_mm_set1_ps(40.0f), _mm_mul_ps(x, _mm_set1_ps(2.f*s_before)));
    const __m128 exp_two_y = ggml_v_expf(two_y);
    const __m128 th = _mm_div_ps(_mm_sub_ps(exp_two_y, one), _mm_add_ps(exp_two_y, one));
    return _mm_mul_ps(th, _mm_set1_ps(s_after));
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__
static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2778,6 +2807,42 @@ static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
}
}
static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 vs_before = _mm512_set1_ps(2.f*s_before);
__m512 vs_after = _mm512_set1_ps(s_after);
//for (; i + 63 < n; i += 64) {
// __m512 x1 = _mm512_loadu_ps(x + i);
// __m512 x2 = _mm512_loadu_ps(x + i + 16);
// __m512 x3 = _mm512_loadu_ps(x + i + 32);
// __m512 x4 = _mm512_loadu_ps(x + i + 48);
// _mm512_storeu_ps(x + i + 0, ggml_v_softcap(x1, vs_before, vs_after));
// _mm512_storeu_ps(x + i + 16, ggml_v_softcap(x2, vs_before, vs_after));
// _mm512_storeu_ps(x + i + 32, ggml_v_softcap(x3, vs_before, vs_after));
// _mm512_storeu_ps(x + i + 48, ggml_v_softcap(x4, vs_before, vs_after));
//}
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(x + i, ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after));
}
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), s_before, s_after));
}
#endif
for (; i < n; ++i) {
x[i] = s_after*tanhf(x[i]*s_before);
}
}
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
int i = 0;
ggml_float sum = 0;
@@ -2968,6 +3033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",
"SOFTCAP",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
@@ -2995,7 +3061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3056,6 +3122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
"flash_attn_ext(x)",
"flash_attn_back(x)",
@@ -3083,7 +3150,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5742,6 +5809,50 @@ struct ggml_tensor * ggml_scale_inplace(
return ggml_scale_impl(ctx, a, s, true);
}
// ggml_softcap
// Shared builder for ggml_softcap / ggml_softcap_inplace.
// Creates a GGML_OP_SOFTCAP node over `a`: a view of `a` when `inplace` is set,
// otherwise a tensor obtained from ggml_dup_tensor. The two scale factors are
// stored in the node's op_params as { s_before, s_after }.
static struct ggml_tensor * ggml_softcap_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        float                 s_before,
        float                 s_after,
        bool                  inplace) {
    GGML_ASSERT(ggml_is_padded_1d(a));

    // a gradient tensor is attached only when the input carries one
    const bool needs_grad = a->grad != NULL;

    struct ggml_tensor * out;
    if (inplace) {
        out = ggml_view_tensor(ctx, a);
    } else {
        out = ggml_dup_tensor(ctx, a);
    }

    // op_params layout: [0] = s_before, [1] = s_after
    float scales[2] = { s_before, s_after };
    ggml_set_op_params(out, scales, sizeof(scales));

    out->op     = GGML_OP_SOFTCAP;
    out->grad   = needs_grad ? ggml_dup_tensor(ctx, out) : NULL;
    out->src[0] = a;

    return out;
}
// Fused soft-cap s_after * tanh(s_before * a), non in-place variant:
// forwards to ggml_softcap_impl with inplace = false.
struct ggml_tensor * ggml_softcap(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s_before,
float s_after) {
return ggml_softcap_impl(ctx, a, s_before, s_after, false);
}
// Fused soft-cap s_after * tanh(s_before * a), in-place variant:
// forwards to ggml_softcap_impl with inplace = true, so the result is a view of a.
struct ggml_tensor * ggml_softcap_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s_before,
float s_after) {
return ggml_softcap_impl(ctx, a, s_before, s_after, true);
}
// ggml_set
static struct ggml_tensor * ggml_set_impl(
@@ -13324,6 +13435,71 @@ static void ggml_compute_forward_scale(
}
}
// ggml_compute_forward_softcap
// F32 forward pass of GGML_OP_SOFTCAP: dst = s_after * tanh(s_before * src0).
// Rows are divided among the threads described by params (ith of nth). When dst
// does not alias src0, each row is first copied into dst and then soft-capped there.
static void ggml_compute_forward_softcap_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    // op_params layout: [0] = s_before, [1] = s_after
    float scales[2];
    memcpy(scales, dst->op_params, sizeof(scales));

    const int tid       = params->ith;
    const int n_threads = params->nth;

    const int n_cols = src0->ne[0];
    const int n_rows = ggml_nrows(src0);

    // each thread takes a contiguous slice of ceil(n_rows/n_threads) rows
    const int slice     = (n_rows + n_threads - 1)/n_threads;
    const int row_first = slice*tid;
    const int row_end   = MIN(row_first + slice, n_rows);

    const size_t src_stride = src0->nb[1];
    const size_t dst_stride = dst->nb[1];

    // in-place nodes share the data pointer with their source: nothing to copy
    const bool needs_copy = dst->data != src0->data;

    for (int r = row_first; r < row_end; r++) {
        if (needs_copy) {
            // src0 has the same shape as dst, so the row index is shared
            memcpy((char *)dst->data + r*dst_stride, (char *)src0->data + r*src_stride, n_cols * sizeof(float));
        }
        float * out_row = (float *) ((char *) dst->data + r*dst_stride);
        ggml_vec_softcap_f32(n_cols, out_row, scales[0], scales[1]);
    }
}
// Type dispatch for GGML_OP_SOFTCAP: only F32 sources are implemented,
// any other source type trips an assertion.
static void ggml_compute_forward_softcap(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    if (src0->type != GGML_TYPE_F32) {
        GGML_ASSERT(false);
        return;
    }

    ggml_compute_forward_softcap_f32(params, dst);
}
// ggml_compute_forward_set
static void ggml_compute_forward_set_f32(
@@ -17175,6 +17351,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_scale(params, tensor);
} break;
case GGML_OP_SOFTCAP:
{
ggml_compute_forward_softcap(params, tensor);
} break;
case GGML_OP_SET:
{
ggml_compute_forward_set(params, tensor);
@@ -17917,6 +18097,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
case GGML_OP_SOFTCAP:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18928,6 +19112,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = 1; //TODO
} break;
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));

View File

@@ -8317,14 +8317,17 @@ static struct ggml_tensor * llm_build_kqv(
//try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
kq = ggml_scale(ctx, kq, 30);
//kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
//kq = ggml_scale(ctx, kq, 30);
kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
}
if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
kq = ggml_tanh(ctx, kq);
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
//kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
//kq = ggml_tanh(ctx, kq);
//kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
@@ -11935,9 +11938,10 @@ struct llm_build_context {
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
//cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
//cur = ggml_tanh(ctx0, cur);
//cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
cb(cur, "result_output", -1);