RoPE cache (#887)

* Introducing rope cache

When computing RoPE, the rotation angles in each layer
are exactly the same, and only depend on the token positions
(and other constant, model dependent parameters).
So, I wonder, why don't we compute the angles just once
and then reuse for the Q and K RoPE in each layer?

This commit does it as a POC on the CPU, and uses it in
the Qwen3-MoE compute graph.

* cuda: neox works

* WIP

* rope_cache: norm works

* Fused rope+rope

* Fused rope+rope (norm)

* Fused rms+rms+rope+rope (neox) - not working

* WIP

* Also qwen3

* Add command line arg to disable rope cache

* Disable RoPE cache if rope type is not neox or norm

* Add missing break after merge with main

* Fused fused_rms+fused_rms+rope+rope (with -mqkv)

* Fused fused_rms+fused_rms+rope+rope (without -mqkv)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-03 18:42:20 +02:00
committed by GitHub
parent 846e736e85
commit fb0d5a995c
12 changed files with 1002 additions and 72 deletions

View File

@@ -639,6 +639,8 @@ extern "C" {
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_BACK,
GGML_OP_ROPE_CACHE,
GGML_OP_ROPE_FAST,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
@@ -2020,6 +2022,26 @@ extern "C" {
float beta_fast,
float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_cache(
struct ggml_context * ctx,
struct ggml_tensor * b,
struct ggml_tensor * c,
int ne0,
int n_dims,
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_fast(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// clamp
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_clamp(

View File

@@ -3062,6 +3062,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
//printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
switch (dst->op) {
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
@@ -3096,7 +3097,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) &&
cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0] &&
ops_are_same_device(cgraph, i, i+2)) {
//printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name);
ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
}
@@ -3244,7 +3244,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_rms_norm(ctx, dst);
break;
case GGML_OP_FUSED_RMS_NORM:
if (i + 2 < cgraph->n_nodes &&
//if (i + 6 < cgraph->n_nodes) {
// printf("=== Fused rms_norm(%s)\n", dst->name);
// for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name);
//}
if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
i += 4;
}
else if (ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
cgraph->nodes[i+2]->op == GGML_OP_RESHAPE &&
cgraph->nodes[i+3]->op == GGML_OP_FUSED_RMS_NORM &&
cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+1], cgraph->nodes[i+4])) {
i += 4;
}
else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
@@ -3318,6 +3338,32 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
break;
case GGML_OP_ROPE_FAST:
if (ENABLE_FUSION && i + 3 < cgraph->n_nodes &&
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
(cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3])) {
i += 3;
}
else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2])) {
i += 2;
}
else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1])) {
i += 1;
}
else {
ggml_cuda_op_rope_fast(ctx, dst);
}
break;
case GGML_OP_ROPE_CACHE:
ggml_cuda_op_rope_cache(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
@@ -4377,6 +4423,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_SOFT_CAP_MAX:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ROPE_FAST:
case GGML_OP_ROPE_CACHE:
return true;
//case GGML_OP_ROPE:
// return ggml_is_contiguous(op->src[0]);

View File

@@ -121,6 +121,226 @@ static __global__ void rope_neox(
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_neox_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
int s01, int s02, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= nelem) {
return;
}
//i = i0 + i1*ne0 + i2*ne0*ne1;
int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
int i1 = i / ne0;
int i0 = i - i1*ne0;
const int idst = i2*ne0*ne1 + i1*ne0 + i0/2;
const int ix = i2*s02 + i1*s01 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = src0[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = src0[ix + i0/2 + 1];
return;
}
const float x0 = src0[ix + 0];
const float x1 = src0[ix + n_dims/2];
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void fused_rope_neox_fast(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int nelem1, int nelem,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= nelem) {
return;
}
const float * src0;
float * dst;
int ne1, s01, s02;
if (i < nelem1) {
src0 = src0_1;
dst = dst_1;
ne1 = ne1_1;
s01 = s01_1;
s02 = s02_1;
} else {
i -= nelem1;
src0 = src0_2;
dst = dst_2;
ne1 = ne1_2;
s01 = s01_2;
s02 = s02_2;
}
int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
int i1 = i / ne0;
int i0 = i - i1*ne0;
const int idst = i2*ne0*ne1 + i1*ne0 + i0/2;
const int ix = i2*s02 + i1*s01 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = src0[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = src0[ix + i0/2 + 1];
return;
}
const float x0 = src0[ix + 0];
const float x1 = src0[ix + n_dims/2];
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const float * src0_2, const float * src1,
const float * c_1, const float * c_2,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) {
int i0 = 2*threadIdx.y;
int i2 = blockIdx.x*blockDim.x + threadIdx.x;
int i1 = blockIdx.z*blockDim.z + threadIdx.z;
const float * src0, *c;
float * dst;
int ne1, s01, s02;
if (i1 < ne1_1) {
ne1 = ne1_1;
s01 = s01_1; s02 = s02_1;
src0 = src0_1 + i1*s01 + i2*s02;
dst = dst_1 + ne0*(i1 + i2*ne1);
c = c_1;
} else {
i1 -= ne1_1;
ne1 = ne1_2;
s01 = s01_2; s02 = s02_2;
src0 = src0_2 + i1*s01 + i2*s02;
dst = dst_2 + ne0*(i1 + i2*ne1);
c = c_2;
}
float sum = i0 < ne0 ? src0[i0]*src0[i0] + src0[i0+1]*src0[i0+1] : 0.0f;
sum = warp_reduce_sum(sum);
if constexpr (CUDA_ROPE_BLOCK_SIZE > WARP_SIZE) {
__shared__ float s_sum[WARP_SIZE];
int warp_id = (i0/2) / WARP_SIZE;
int lane_id = (i0/2) % WARP_SIZE;
if (lane_id == 0) s_sum[warp_id] = sum;
__syncthreads();
sum = lane_id < CUDA_ROPE_BLOCK_SIZE / WARP_SIZE ? s_sum[lane_id] : 0;
sum = warp_reduce_sum(sum);
}
float norm = rsqrtf(sum/ne0 + eps);
if (i0 >= ne0) return;
if (i0 >= n_dims) {
dst[i0 + 0] = norm*c[i0 + 0]*src0[i0 + 0];
dst[i0 + 1] = norm*c[i0 + 1]*src0[i0 + 1];
return;
}
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
const float x0 = norm*c[i0/2 + 0]*src0[i0/2 + 0];
const float x1 = norm*c[i0/2 + n_dims/2]*src0[i0/2 + n_dims/2];
dst[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
dst[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
int s01, int s02, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= nelem) {
return;
}
int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
int i1 = i / ne0;
int i0 = i - i1*ne0;
const int idst = i2*ne0*ne1 + i1*ne0 + i0;
const int ix = i2*s02 + i1*s01 + i0;
if (i0 >= n_dims) {
dst[idst + 0] = src0[ix + 0];
dst[idst + 1] = src0[ix + 1];
return;
}
const float x0 = src0[ix + 0];
const float x1 = src0[ix + 1];
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
static __global__ void fused_rope_norm_fast(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int nelem1, int nelem,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= nelem) {
return;
}
const float * src0;
float * dst;
int ne1, s01, s02;
if (i < nelem1) {
src0 = src0_1;
dst = dst_1;
ne1 = ne1_1;
s01 = s01_1;
s02 = s02_1;
} else {
i -= nelem1;
src0 = src0_2;
dst = dst_2;
ne1 = ne1_2;
s01 = s01_2;
s02 = s02_2;
}
int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
int i1 = i / ne0;
int i0 = i - i1*ne0;
const int idst = i2*ne0*ne1 + i1*ne0 + i0;
const int ix = i2*s02 + i1*s01 + i0;
if (i0 >= n_dims) {
dst[idst + 0] = src0[ix + 0];
dst[idst + 1] = src0[ix + 1];
return;
}
const float x0 = src0[ix + 0];
const float x1 = src0[ix + 1];
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
@@ -272,6 +492,84 @@ static void rope_neox_cuda(
}
}
static void rope_neox_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
}
static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int nelem1 = ne0*ne1_1*ne2;
const int nelem2 = ne0*ne1_2*ne2;
const int nelem = nelem1 + nelem2;
const int n_blocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
fused_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, dst_1, dst_2, ne0, ne1_1, ne1_2, nelem1, nelem,
s01_1, s02_1, s01_2, s02_2, n_dims);
}
static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
const float * c_1, const float * c_2,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, float eps, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const dim3 block_nums(ne2, 1, ne1_1 + ne1_2);
fused_rms_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2,
s01_1, s02_1, s01_2, s02_2, n_dims, eps);
}
static void fused_rope_norm_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int nelem1 = ne0*ne1_1*ne2;
const int nelem2 = ne0*ne1_2*ne2;
const int nelem = nelem1 + nelem2;
const int n_blocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
fused_rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, dst_1, dst_2, ne0, ne1_1, ne1_2, nelem1, nelem,
s01_1, s02_1, s01_2, s02_2, n_dims);
}
static void rope_norm_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
}
static void rope_multi_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
// TODO
rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
static void rope_vision_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
// TODO
rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
}
template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
@@ -448,3 +746,270 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_impl<false>(ctx, dst);
}
template <bool forward, bool has_ff>
static __global__ void k_rope_cache(int nelem, int ne0, float * dst, const int * pos, const float * freq_factors,
float theta_scale, float freq_scale, rope_corr_dims corr_dims, float ext_factor, float attn_factor) {
int i = 2*(blockIdx.x*blockDim.x + threadIdx.x);
if (i >= nelem) {
return;
}
int i2 = i / ne0;
int i0 = i % ne0;
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, dst[i], dst[i+1]);
if constexpr (!forward) {
dst[i+1] *= -1;
}
}
template <bool forward>
void ggml_cuda_op_rope_cache_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
const struct ggml_tensor * tpos = dst->src[0];
GGML_ASSERT(tpos->type == GGML_TYPE_I32);
GGML_ASSERT(tpos->ne[0] == dst->ne[1]);
GGML_ASSERT(n_dims <= dst->ne[0]);
GGML_ASSERT(n_dims % 2 == 0);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == dst->ne[0]);
}
const float * freq_factors = NULL;
if (dst->src[1] != NULL) {
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->ne[0] >= n_dims / 2);
freq_factors = (const float *) dst->src[1]->data;
}
const int * pos = (const int *) dst->src[0]->data;
if (dst->src[1]!= nullptr) {
freq_factors = (const float *) dst->src[1]->data;
}
int nelem = ggml_nelements(dst);
int nblocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1)/(2*CUDA_ROPE_BLOCK_SIZE);
if (freq_factors) {
k_rope_cache<true, true ><<<nblocks, CUDA_ROPE_BLOCK_SIZE, 0, ctx.stream()>>>(ggml_nelements(dst), dst->ne[0],
(float *)dst->data, pos, freq_factors, theta_scale, freq_scale, corr_dims, ext_factor, attn_factor);
} else {
k_rope_cache<true, false><<<nblocks, CUDA_ROPE_BLOCK_SIZE, 0, ctx.stream()>>>(ggml_nelements(dst), dst->ne[0],
(float *)dst->data, pos, freq_factors, theta_scale, freq_scale, corr_dims, ext_factor, attn_factor);
}
}
void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_rope_cache_impl<true>(ctx, dst);
}
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == dst->type);
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
// compute
if (is_neox) {
//printf("Using neox\n");
rope_neox_fast_cuda(
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
} else if (is_mrope && !is_vision) {
rope_multi_fast_cuda(
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
} else if (is_vision) {
rope_vision_fast_cuda(
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
} else {
//printf("Using norm\n");
rope_norm_fast_cuda(
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
}
}
bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
if (dst1->src[1] != dst2->src[1]) return false;
const ggml_tensor * src0_1 = dst1->src[0];
const ggml_tensor * src0_2 = dst2->src[0];
const ggml_tensor * src1 = dst1->src[1];
if (src0_1->type != GGML_TYPE_F32) return false;
if (src0_2->type != GGML_TYPE_F32) return false;
if (dst1->type != GGML_TYPE_F32) return false;
if (dst2->type != GGML_TYPE_F32) return false;
if (src1->type != dst1->type) return false;
if (src0_1->ne[0] != src0_2->ne[0]) return false;
if (src0_1->ne[2] != src0_2->ne[2]) return false;
if (src0_1->ne[3] != src0_2->ne[3]) return false;
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_vision || is_mrope) return false; // not implemented
const int64_t ne00 = src0_1->ne[0]; // head dims
const int64_t ne02 = src0_1->ne[2]; // num tokens
const int64_t ne01_1 = src0_1->ne[1]; // num heads
const int64_t ne01_2 = src0_2->ne[1]; // num heads
const size_t s01_1 = src0_1->nb[1] / ggml_type_size(src0_1->type);
const size_t s02_1 = src0_1->nb[2] / ggml_type_size(src0_1->type);
const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type);
const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type);
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
// compute
if (is_neox) {
fused_rope_neox_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, ctx.stream());
} else {
fused_rope_norm_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, ctx.stream());
}
return true;
}
bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
if (dst1->src[1] != dst2->src[1]) return false;
const auto rms_1 = dst1->src[0];
const auto rms_2 = dst2->src[0];
const auto src1 = dst1->src[1];
if (rms_1->op != GGML_OP_FUSED_RMS_NORM) return false;
if (rms_2->op != GGML_OP_FUSED_RMS_NORM) return false;
const auto src0_1 = rms_1->src[0];
const auto src0_2 = rms_2->src[0];
const auto c_1 = rms_1->src[1];
const auto c_2 = rms_2->src[1];
if (src0_1->type != GGML_TYPE_F32) return false;
if (src0_2->type != GGML_TYPE_F32) return false;
if (dst1->type != GGML_TYPE_F32) return false;
if (dst2->type != GGML_TYPE_F32) return false;
if (src1->type != dst1->type) return false;
if (c_1->type != GGML_TYPE_F32) return false;
if (c_2->type != GGML_TYPE_F32) return false;
if (src0_1->ne[0] != src0_2->ne[0]) return false;
if (src0_1->ne[2] != src0_2->ne[2]) return false;
if (src0_1->ne[3] != src0_2->ne[3]) return false;
if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false;
GGML_ASSERT(ggml_nrows(c_1) == 1);
GGML_ASSERT(ggml_nrows(c_2) == 1);
GGML_ASSERT(c_1->ne[0] == src0_1->ne[0]);
GGML_ASSERT(c_2->ne[0] == src0_2->ne[0]);
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_vision || is_mrope) return false; // not implemented
if (!is_neox) return false; // TODO
float eps1, eps2;
memcpy(&eps1, rms_1->op_params, sizeof(float));
memcpy(&eps2, rms_2->op_params, sizeof(float));
if (eps1 != eps2) return false;
const int64_t ne00 = src0_1->ne[0]; // head dims
const int64_t ne02 = src0_1->ne[2]; // num tokens
const int64_t ne01_1 = src0_1->ne[1]; // num heads
const int64_t ne01_2 = src0_2->ne[1]; // num heads
const size_t s01_1 = src0_1->nb[1] / ggml_type_size(src0_1->type);
const size_t s02_1 = src0_1->nb[2] / ggml_type_size(src0_1->type);
const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type);
const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type);
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
// compute
fused_rms_rope_neox_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(const float *)c_1->data, (const float *)c_2->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1,
ctx.stream());
return true;
}

View File

@@ -5,3 +5,11 @@
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);

View File

@@ -4242,6 +4242,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SOFT_MAX_BACK",
"ROPE",
"ROPE_BACK",
"ROPE_CACHE",
"ROPE_FAST",
"CLAMP",
"CONV_TRANSPOSE_1D",
"IM2COL",
@@ -4290,7 +4292,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4347,6 +4349,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"soft_max_back(x)",
"rope(x)",
"rope_back(x)",
"rope_cache(pos)",
"rope_fast(x)",
"clamp(x)",
"conv_transpose_1d(x)",
"im2col(x)",
@@ -4395,7 +4399,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x),"
};
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -8664,6 +8668,80 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
// ggml_rope
struct ggml_tensor * ggml_rope_cache(
struct ggml_context * ctx,
struct ggml_tensor * b,
struct ggml_tensor * c,
int ne0,
int n_dims,
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
GGML_ASSERT(!mrope_used);
//if (mrope_used) {
// GGML_ASSERT(ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
//} else {
// GGML_ASSERT(a->ne[2] == b->ne[0]);
//}
if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, b->ne[0]);
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
//if (mrope_used) {
// memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
//} else {
// memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
//}
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_CACHE;
result->src[0] = b;
result->src[1] = c;
return result;
}
struct ggml_tensor * ggml_rope_fast(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(a->ne[0] <= b->ne[0]);
GGML_ASSERT(a->ne[2] <= b->ne[1]);
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(b->type == GGML_TYPE_F32);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_ROPE_FAST;
result->src[0] = a;
result->src[1] = b;
return result;
}
static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -18396,6 +18474,181 @@ static void ggml_mrope_cache_init(
}
}
static void ggml_compute_forward_rope_cache_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
const bool forward) {
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
const struct ggml_tensor * tpos = dst->src[0];
GGML_ASSERT(tpos->type == GGML_TYPE_I32);
GGML_ASSERT(tpos->ne[0] == dst->ne[1]);
GGML_ASSERT(n_dims <= dst->ne[0]);
GGML_ASSERT(n_dims % 2 == 0);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == dst->ne[0]);
}
const float * freq_factors = NULL;
if (dst->src[1] != NULL) {
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->ne[0] >= n_dims / 2);
freq_factors = (const float *) dst->src[1]->data;
}
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) tpos->data;
int ith = params->ith;
int nth = params->nth;
const int npt = (dst->ne[1] + nth - 1)/nth;
int first = npt*ith;
int last = MIN(dst->ne[1], first + npt);
int64_t ne0 = dst->ne[0];
int64_t ne2 = dst->ne[1];
for (int i1 = first; i1 < last; ++i1) {
float * cache = (float *)((char *)dst->data + dst->nb[1]*i1);
if (!is_mrope) {
const int64_t p = pos[i1];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i1];
const int64_t p_h = pos[i1 + ne2];
const int64_t p_w = pos[i1 + ne2 * 2];
const int64_t p_e = pos[i1 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
}
}
static void ggml_compute_forward_rope_fast_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[0] <= src1->ne[0]);
GGML_ASSERT(src0->ne[2] <= src1->ne[1]);
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
const int ith = params->ith;
const int nth = params->nth;
const int nrows = ggml_nrows(src0);
const int npt = (nrows + nth - 1)/nth;
const int first = ith*npt;
const int last = MIN(first + npt, nrows);
const int ne02 = src0->ne[2];
const int ne01 = src0->ne[1];
const int ne00 = src0->ne[0];
for (int ir = first; ir < last; ++ir) {
const int i3 = ir/(ne01*ne02);
const int i2 = (ir - i3*ne01*ne02)/ne01;
const int i1 = ir - i3*ne01*ne02 - i2*ne01;
const float * c = (const float *)((const char *)src1->data + i2*src1->nb[1]);
const float * x = (const float *)((const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3]);
float * y = ( float *)(( char *)dst->data + i1* dst->nb[1] + i2* dst->nb[2] + i3* dst->nb[3]);
if (is_neox || is_mrope) {
const int n_gap = is_vision ? n_dims : n_dims/2;
for (int i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = c[i0 + 0];
const float sin_theta = c[i0 + 1];
const float x0 = x[ic];
const float x1 = x[ic+n_gap];
y[ic ] = x0*cos_theta - x1*sin_theta;
y[ic+n_gap] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = c[i0 + 0];
const float sin_theta = c[i0 + 1];
const float x0 = x[i0+0];
const float x1 = x[i0+1];
y[i0+0] = x0*cos_theta - x1*sin_theta;
y[i0+1] = x0*sin_theta + x1*cos_theta;
}
}
if (is_vision) {
for (int i0 = n_dims; i0 < ne00; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = c[i0 + 0];
const float sin_theta = c[i0 + 1];
const float x0 = x[ic];
const float x1 = x[ic+n_dims];
y[ic] = x0*cos_theta - x1*sin_theta;
y[ic+n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
// fill the remain channels with data from src tensor
for (int i0 = n_dims; i0 < ne00; i0 += 2) {
y[i0+0] = x[i0+0];
y[i0+1] = x[i0+1];
}
}
}
}
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@@ -22584,6 +22837,14 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_rope_back(params, tensor);
} break;
case GGML_OP_ROPE_CACHE:
{
ggml_compute_forward_rope_cache_f32(params, tensor, true);
} break;
case GGML_OP_ROPE_FAST:
{
ggml_compute_forward_rope_fast_f32(params, tensor);
} break;
case GGML_OP_CLAMP:
{
ggml_compute_forward_clamp(params, tensor);
@@ -23635,6 +23896,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
case GGML_OP_ROPE_CACHE:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_ROPE_FAST:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_GLU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24408,6 +24677,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ROPE_CACHE:
case GGML_OP_ROPE_FAST:
case GGML_OP_ADD_REL_POS:
{
n_tasks = n_threads;