mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-11 08:50:11 +00:00
RoPE(Neox, Metal): don't use power functions in a loop
Speeds up Bitnet by ~2% on Metal.
This commit is contained in:
@@ -1674,7 +1674,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
static inline void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
||||
thread float * cos_theta, thread float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
@@ -1828,35 +1828,69 @@ kernel void kernel_rope_neox(
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims);
|
||||
const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims);
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < n_dims) {
|
||||
const int64_t ic = i0/2;
|
||||
int64_t i0 = 2*tiitg;
|
||||
for ( ; i0 < n_dims; i0 += 2*tptg.x) {
|
||||
const int64_t ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
theta *= theta_multiplier;
|
||||
}
|
||||
for ( ; i0 < ne0; i0 += 2*tptg.x) {
|
||||
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
|
||||
// Original version
|
||||
//for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
// if (i0 < n_dims) {
|
||||
// const int64_t ic = i0/2;
|
||||
|
||||
// // Who thought that having a pow() evaluation in a loop is a good idea?
|
||||
// //const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
// const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
// rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
// const float x0 = src[0];
|
||||
// const float x1 = src[n_dims/2];
|
||||
|
||||
// dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
// dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
// theta *= theta_multiplier;
|
||||
// } else {
|
||||
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
// dst_data[0] = src[0];
|
||||
// dst_data[1] = src[1];
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
|
||||
Reference in New Issue
Block a user