Add mtmd: refresh CPU rope

This commit is contained in:
Iwan Kawrakow
2025-09-25 18:49:25 +03:00
parent 879201c26d
commit 293a59d4f1

View File

@@ -17814,6 +17814,22 @@ static void ggml_compute_forward_clamp(
// ggml_compute_forward_rope
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
dims[0] = MAX(0, start);
dims[1] = MIN(n_dims - 1, end);
}
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
return 1 - MIN(1, MAX(0, y));
@@ -17838,12 +17854,6 @@ static void rope_yarn(
*sin_theta = sinf(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
static void ggml_rope_cache_init(
float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
@@ -17860,14 +17870,62 @@ static void ggml_rope_cache_init(
}
}
GGML_CALL void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
dims[0] = MAX(0, start);
dims[1] = MIN(n_dims - 1, end);
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta_t = theta_base_t;
float theta_h = theta_base_h;
float theta_w = theta_base_w;
float theta_e = theta_base_e; // extra position id for vision encoder
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
GGML_ASSERT(sect_dims <= ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
int sector = (i0 / 2) % sect_dims;
if (indep_sects) {
// compute theta independently for each dim sections
// (i.e. reset corresponding theta when `i0` go from one section to another)
if (sector == 0) {
theta_t = theta_base_t;
}
else if (sector == sections[0]) {
theta_h = theta_base_h;;
}
else if (sector == sec_w) {
theta_w = theta_base_w;
}
else if (sector == sec_e) {
theta_e = theta_base_e;
}
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
);
cache[i0 + 1] *= sin_sign;
theta_t *= theta_scale;
theta_w *= theta_scale;
theta_h *= theta_scale;
theta_e *= theta_scale;
}
}
static void ggml_compute_forward_rope_f32(
@@ -17880,6 +17938,7 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -17893,6 +17952,7 @@ static void ggml_compute_forward_rope_f32(
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
@@ -17924,7 +17984,17 @@ static void ggml_compute_forward_rope_f32(
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
@@ -17940,18 +18010,63 @@ static void ggml_compute_forward_rope_f32(
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (!is_neox) {
if (is_neox || is_mrope) {
if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
@@ -17965,8 +18080,10 @@ static void ggml_compute_forward_rope_f32(
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
@@ -17976,19 +18093,20 @@ static void ggml_compute_forward_rope_f32(
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
}
} else {
// fill the remain channels with data from src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}
@@ -18006,6 +18124,7 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -18018,6 +18137,8 @@ static void ggml_compute_forward_rope_f16(
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
@@ -18049,7 +18170,17 @@ static void ggml_compute_forward_rope_f16(
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & 2;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
@@ -18067,16 +18198,61 @@ static void ggml_compute_forward_rope_f16(
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (!is_neox) {
if (is_neox || is_mrope) {
if (is_vision) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
@@ -18090,8 +18266,10 @@ static void ggml_compute_forward_rope_f16(
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
@@ -18101,19 +18279,19 @@ static void ggml_compute_forward_rope_f16(
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
} else {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}