mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
RoPE(Neox, Metal): don't use power functions in a loop
Speeds up Bitnet by ~2% on Metal.
This commit is contained in:
@@ -1674,7 +1674,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|||||||
|
|
||||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||||
static void rope_yarn(
|
static inline void rope_yarn(
|
||||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||||
thread float * cos_theta, thread float * sin_theta) {
|
thread float * cos_theta, thread float * sin_theta) {
|
||||||
// Get n-d rotational scaling corrected for extrapolation
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
@@ -1828,35 +1828,69 @@ kernel void kernel_rope_neox(
|
|||||||
const float theta_base = (float) pos[i2];
|
const float theta_base = (float) pos[i2];
|
||||||
const float inv_ndims = -1.f/n_dims;
|
const float inv_ndims = -1.f/n_dims;
|
||||||
|
|
||||||
|
float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims);
|
||||||
|
const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims);
|
||||||
|
|
||||||
float cos_theta;
|
float cos_theta;
|
||||||
float sin_theta;
|
float sin_theta;
|
||||||
|
|
||||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
int64_t i0 = 2*tiitg;
|
||||||
if (i0 < n_dims) {
|
for ( ; i0 < n_dims; i0 += 2*tptg.x) {
|
||||||
const int64_t ic = i0/2;
|
const int64_t ic = i0/2;
|
||||||
|
|
||||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||||
|
|
||||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||||
|
|
||||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||||
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||||
|
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
const float x0 = src[0];
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
const float x1 = src[n_dims/2];
|
||||||
|
|
||||||
const float x0 = src[0];
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
const float x1 = src[n_dims/2];
|
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
theta *= theta_multiplier;
|
||||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
||||||
} else {
|
|
||||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
||||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
||||||
|
|
||||||
dst_data[0] = src[0];
|
|
||||||
dst_data[1] = src[1];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
for ( ; i0 < ne0; i0 += 2*tptg.x) {
|
||||||
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
dst_data[0] = src[0];
|
||||||
|
dst_data[1] = src[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Original version
|
||||||
|
//for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||||
|
// if (i0 < n_dims) {
|
||||||
|
// const int64_t ic = i0/2;
|
||||||
|
|
||||||
|
// // Who thought that having a pow() evaluation in a loop is a good idea?
|
||||||
|
// //const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||||
|
|
||||||
|
// const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||||
|
|
||||||
|
// rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||||
|
|
||||||
|
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||||
|
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||||
|
|
||||||
|
// const float x0 = src[0];
|
||||||
|
// const float x1 = src[n_dims/2];
|
||||||
|
|
||||||
|
// dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
|
// dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||||
|
|
||||||
|
// theta *= theta_multiplier;
|
||||||
|
// } else {
|
||||||
|
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
|
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
// dst_data[0] = src[0];
|
||||||
|
// dst_data[1] = src[1];
|
||||||
|
// }
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||||
|
|||||||
Reference in New Issue
Block a user