RoPE(Neox, Metal): don't use power functions in a loop

Speeds up Bitnet by ~2% on Metal.
This commit is contained in:
Iwan Kawrakow
2024-06-26 11:22:47 +02:00
parent 767bce7caf
commit 641dd6bc68

View File

@@ -1674,7 +1674,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
static inline void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
@@ -1828,35 +1828,69 @@ kernel void kernel_rope_neox(
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/n_dims;
float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims);
const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims);
float cos_theta;
float sin_theta;
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
if (i0 < n_dims) {
const int64_t ic = i0/2;
int64_t i0 = 2*tiitg;
for ( ; i0 < n_dims; i0 += 2*tptg.x) {
const int64_t ic = i0/2;
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
theta *= theta_multiplier;
}
for ( ; i0 < ne0; i0 += 2*tptg.x) {
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
// Original version
//for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
// if (i0 < n_dims) {
// const int64_t ic = i0/2;
// // Who thought that having a pow() evaluation in a loop is a good idea?
// //const float theta = theta_base * pow(freq_base, inv_ndims*i0);
// const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
// rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
// const float x0 = src[0];
// const float x1 = src[n_dims/2];
// dst_data[0] = x0*cos_theta - x1*sin_theta;
// dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
// theta *= theta_multiplier;
// } else {
// device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
// device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
// dst_data[0] = src[0];
// dst_data[1] = src[1];
// }
//}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;