mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-14 15:57:37 +00:00
Introducing rope cache
When computing RoPE, the rotation angles in each layer are exactly the same, and only depend on the token positions (and other constant, model dependent parameters). So, I wonder, why don't we compute the angles just once and then reuse for the Q and K RoPE in each layer? This commit does it as a POC on the CPU, and uses it in the Qwen3-MoE compute graph.
This commit is contained in:
@@ -639,6 +639,8 @@ extern "C" {
|
||||
GGML_OP_SOFT_MAX_BACK,
|
||||
GGML_OP_ROPE,
|
||||
GGML_OP_ROPE_BACK,
|
||||
GGML_OP_ROPE_CACHE,
|
||||
GGML_OP_ROPE_FAST,
|
||||
GGML_OP_CLAMP,
|
||||
GGML_OP_CONV_TRANSPOSE_1D,
|
||||
GGML_OP_IM2COL,
|
||||
@@ -2020,6 +2022,26 @@ extern "C" {
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rope_cache(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int ne0,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rope_fast(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// clamp
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_clamp(
|
||||
|
||||
276
ggml/src/ggml.c
276
ggml/src/ggml.c
@@ -4242,6 +4242,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"SOFT_MAX_BACK",
|
||||
"ROPE",
|
||||
"ROPE_BACK",
|
||||
"ROPE_CACHE",
|
||||
"ROPE_FAST",
|
||||
"CLAMP",
|
||||
"CONV_TRANSPOSE_1D",
|
||||
"IM2COL",
|
||||
@@ -4290,7 +4292,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -4347,6 +4349,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"soft_max_back(x)",
|
||||
"rope(x)",
|
||||
"rope_back(x)",
|
||||
"rope_cache(pos)",
|
||||
"rope_fast(x)",
|
||||
"clamp(x)",
|
||||
"conv_transpose_1d(x)",
|
||||
"im2col(x)",
|
||||
@@ -4395,7 +4399,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x),"
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
|
||||
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -8664,6 +8668,80 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||
|
||||
// ggml_rope
|
||||
|
||||
struct ggml_tensor * ggml_rope_cache(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int ne0,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
|
||||
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
|
||||
bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
|
||||
GGML_ASSERT(!mrope_used);
|
||||
//if (mrope_used) {
|
||||
// GGML_ASSERT(ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
|
||||
//} else {
|
||||
// GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
//}
|
||||
|
||||
if (c) {
|
||||
GGML_ASSERT(c->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->ne[0] >= n_dims / 2);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, b->ne[0]);
|
||||
|
||||
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||
memcpy(params + 5, &freq_base, sizeof(float));
|
||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||
memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
|
||||
//if (mrope_used) {
|
||||
// memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
|
||||
//} else {
|
||||
// memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
|
||||
//}
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_ROPE_CACHE;
|
||||
result->src[0] = b;
|
||||
result->src[1] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_fast(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
GGML_ASSERT(a->ne[0] <= b->ne[0]);
|
||||
GGML_ASSERT(a->ne[2] <= b->ne[1]);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_ROPE_FAST;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_rope_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@@ -8684,6 +8762,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
printf("%s: b->ne[0] = %ld\n", __func__, b->ne[0]);
|
||||
|
||||
bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
|
||||
if (mrope_used) {
|
||||
@@ -18396,6 +18475,181 @@ static void ggml_mrope_cache_init(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_cache_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const bool forward) {
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
int sections[4];
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
|
||||
const struct ggml_tensor * tpos = dst->src[0];
|
||||
GGML_ASSERT(tpos->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(tpos->ne[0] == dst->ne[1]);
|
||||
|
||||
GGML_ASSERT(n_dims <= dst->ne[0]);
|
||||
GGML_ASSERT(n_dims % 2 == 0);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == dst->ne[0]);
|
||||
}
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (dst->src[1] != NULL) {
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) dst->src[1]->data;
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
const float sin_sign = forward ? 1.0f : -1.0f;
|
||||
|
||||
const int32_t * pos = (const int32_t *) tpos->data;
|
||||
|
||||
int ith = params->ith;
|
||||
int nth = params->nth;
|
||||
const int npt = (dst->ne[1] + nth - 1)/nth;
|
||||
|
||||
int first = npt*ith;
|
||||
int last = MIN(dst->ne[1], first + npt);
|
||||
|
||||
int64_t ne0 = dst->ne[0];
|
||||
int64_t ne2 = dst->ne[1];
|
||||
|
||||
for (int i1 = first; i1 < last; ++i1) {
|
||||
float * cache = (float *)((char *)dst->data + dst->nb[1]*i1);
|
||||
if (!is_mrope) {
|
||||
const int64_t p = pos[i1];
|
||||
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
else {
|
||||
const int64_t p_t = pos[i1];
|
||||
const int64_t p_h = pos[i1 + ne2];
|
||||
const int64_t p_w = pos[i1 + ne2 * 2];
|
||||
const int64_t p_e = pos[i1 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_fast_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[0] <= src1->ne[0]);
|
||||
GGML_ASSERT(src0->ne[2] <= src1->ne[1]);
|
||||
|
||||
const int n_dims = ((const int32_t *) src1->op_params)[1];
|
||||
const int mode = ((const int32_t *) src1->op_params)[2];
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nrows = ggml_nrows(src0);
|
||||
const int npt = (nrows + nth - 1)/nth;
|
||||
const int first = ith*npt;
|
||||
const int last = MIN(first + npt, nrows);
|
||||
|
||||
const int ne02 = src0->ne[2];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne00 = src0->ne[0];
|
||||
|
||||
for (int ir = first; ir < last; ++ir) {
|
||||
const int i3 = ir/(ne01*ne02);
|
||||
const int i2 = (ir - i3*ne01*ne02)/ne01;
|
||||
const int i1 = ir - i3*ne01*ne02 - i2*ne01;
|
||||
const float * c = (const float *)((const char *)src1->data + i2*src1->nb[1]);
|
||||
const float * x = (const float *)((const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3]);
|
||||
float * y = ( float *)(( char *)dst->data + i1* dst->nb[1] + i2* dst->nb[2] + i3* dst->nb[3]);
|
||||
if (is_neox || is_mrope) {
|
||||
const int n_gap = is_vision ? n_dims : n_dims/2;
|
||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = c[i0 + 0];
|
||||
const float sin_theta = c[i0 + 1];
|
||||
|
||||
const float x0 = x[ic];
|
||||
const float x1 = x[ic+n_gap];
|
||||
|
||||
y[ic ] = x0*cos_theta - x1*sin_theta;
|
||||
y[ic+n_gap] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const float cos_theta = c[i0 + 0];
|
||||
const float sin_theta = c[i0 + 1];
|
||||
|
||||
const float x0 = x[i0+0];
|
||||
const float x1 = x[i0+1];
|
||||
|
||||
y[i0+0] = x0*cos_theta - x1*sin_theta;
|
||||
y[i0+1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
for (int i0 = n_dims; i0 < ne00; i0 += 2) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float cos_theta = c[i0 + 0];
|
||||
const float sin_theta = c[i0 + 1];
|
||||
|
||||
const float x0 = x[ic];
|
||||
const float x1 = x[ic+n_dims];
|
||||
|
||||
y[ic] = x0*cos_theta - x1*sin_theta;
|
||||
y[ic+n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
// fill the remain channels with data from src tensor
|
||||
for (int i0 = n_dims; i0 < ne00; i0 += 2) {
|
||||
y[i0+0] = x[i0+0];
|
||||
y[i0+1] = x[i0+1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
@@ -22584,6 +22838,14 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
|
||||
{
|
||||
ggml_compute_forward_rope_back(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE_CACHE:
|
||||
{
|
||||
ggml_compute_forward_rope_cache_f32(params, tensor, true);
|
||||
} break;
|
||||
case GGML_OP_ROPE_FAST:
|
||||
{
|
||||
ggml_compute_forward_rope_fast_f32(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
ggml_compute_forward_clamp(params, tensor);
|
||||
@@ -23635,6 +23897,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ROPE_CACHE:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
}
|
||||
case GGML_OP_ROPE_FAST:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
}
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
@@ -24408,6 +24678,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_ROPE_CACHE:
|
||||
case GGML_OP_ROPE_FAST:
|
||||
case GGML_OP_ADD_REL_POS:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
@@ -3468,6 +3468,9 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
ggml_set_input(rope_cache);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
@@ -3483,18 +3486,20 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
|
||||
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
|
||||
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
|
||||
//Qcur = ggml_rope_ext(
|
||||
// ctx0, Qcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow
|
||||
// );
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
|
||||
//Kcur = ggml_rope_ext(
|
||||
// ctx0, Kcur, inp_pos, nullptr,
|
||||
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
// ext_factor, attn_factor, beta_fast, beta_slow
|
||||
// );
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
|
||||
Reference in New Issue
Block a user