From 054c31cf8f8da19c24e9849ae4b0eb28b3b69631 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Nov 2025 09:08:42 +0100 Subject: [PATCH] Fuse Q and K RoPE (#980) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 14 +- ggml/src/ggml-cuda/rope.cu | 424 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/rope.cuh | 2 + 3 files changed, 439 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e494c582..f6891cac 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3346,7 +3346,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_soft_cap_max(ctx, dst); break; case GGML_OP_ROPE: - ggml_cuda_op_rope(ctx, dst); + if (fusion && i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_VIEW && + cgraph->nodes[i+2]->op == GGML_OP_ROPE && + ggml_cuda_op_rope_rope(ctx, dst, cgraph->nodes[i+2])) { + i += 2; + } + else if (fusion && i + 1 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_ROPE && + ggml_cuda_op_rope_rope(ctx, dst, cgraph->nodes[i+1])) { + i += 1; + } else { + ggml_cuda_op_rope(ctx, dst); + } break; case GGML_OP_ROPE_BACK: ggml_cuda_op_rope_back(ctx, dst); diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index ff4b4e58..a75d9408 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -1026,3 +1026,427 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens ctx.stream()); return true; } + +template +static __global__ void rope_rope_neox( + const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2, + const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + int row_x = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = blockDim.x*blockIdx.x + threadIdx.x; + + const T * x; + T * dst; + int ne1, s1, s2; + if (row_x < ne1_1) { + x = x1; + dst = dst1; + ne1 = ne1_1; + s1 = s1_1; + s2 = s2_1; + } else { + x = x2; + dst = dst2; + row_x -= ne1_1; + ne1 = ne1_2; + s1 = s1_2; + s2 = s2_2; + } + + const int idst = (row_x + channel_x*ne1)*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float t0 = x[ix + 0]; + const float t1 = x[ix + n_dims/2]; + + dst[idst + 0] = t0*cos_theta - t1*sin_theta; + dst[idst + n_dims/2] = t0*sin_theta + t1*cos_theta; +} + +template +static __global__ void rope_rope_norm( + const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2, + const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + int row_x = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = blockDim.x*blockIdx.x + threadIdx.x; + + const T * x; + T * dst; + int ne1, s1, s2; + if (row_x < ne1_1) { + x = x1; + dst = dst1; + ne1 = ne1_1; + s1 = s1_1; + s2 = s2_1; + } else { + x = x2; + dst = dst2; + row_x -= ne1_1; + ne1 = ne1_2; + s1 = s1_2; + s2 = s2_2; + } + + const int idst = (row_x + channel_x*ne1)*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; + + if (i0 >= n_dims) { + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; + + return; + } + + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float t0 = x[ix + 0]; + const float t1 = x[ix + 1]; + + dst[idst + 0] = t0*cos_theta - t1*sin_theta; + dst[idst + 1] = t0*sin_theta + t1*cos_theta; +} + +template +static __global__ void rope_rope_multi( + const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2, + const int ne2, const int s1_1, const int s2_1, const int s1_2, const int s2_2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + int row_x = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = blockDim.x*blockIdx.x + threadIdx.x; + + const T * x; + T * dst; + int ne1, s1, s2; + if (row_x < ne1_1) { + x = x1; + dst = dst1; + ne1 = ne1_1; + s1 = s1_1; + s2 = s2_1; + } else { + x = x2; + dst = dst2; + row_x -= ne1_1; + ne1 = ne1_2; + s1 = s1_2; + s2 = s2_2; + } + + const int idst = (row_x + channel_x*ne1)*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } + } else { + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float t0 = x[ix + 0]; + const float t1 = x[ix + n_dims/2]; + + dst[idst + 0] = t0*cos_theta - t1*sin_theta; + dst[idst + n_dims/2] = t0*sin_theta + t1*cos_theta; +} + //rope_rope_vision<<>>( + // x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + // attn_factor, corr_dims, theta_scale, freq_factors); + +template +static __global__ void rope_rope_vision( + const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2, + const int ne2, const int s1_1, const int s2_1, const int s1_2, const int s2_2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + int row_x = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = blockDim.x*blockIdx.x + threadIdx.x; + + const T * x; + T * dst; + int ne1, s1, s2; + if (row_x < ne1_1) { + x = x1; + dst = dst1; + ne1 = ne1_1; + s1 = s1_1; + s2 = s2_1; + } else { + x = x2; + dst = dst2; + row_x -= ne1_1; + ne1 = ne1_2; + s1 = s1_2; + s2 = s2_2; + } + + const int idst = (row_x + channel_x*ne1)*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[channel_x]*powf(theta_scale, p); + } + else if (sector >= sections.v[0] && sector < sec_w) { + const int p = sector - sections.v[0]; + theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float t0 = x[ix + 0]; + const float t1 = x[ix + n_dims]; + + dst[idst + 0] = t0*cos_theta - t1*sin_theta; + dst[idst + n_dims] = t0*sin_theta + t1*cos_theta; +} + +template +static void rope_rope_cuda(int kernel, + const T * x1, const T * x2, T * dst1, T * dst2, const int ne0, const int ne1_1, const int ne1_2, + const int s1_1, const int s2_1, const int se1_2, const int se2_2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_mrope, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, ne1_1 + ne1_2); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + switch (kernel) { + case 0: + rope_rope_neox<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); + break; + case 1: + rope_rope_multi<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_mrope); + break; + case 2: + rope_rope_vision<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + break; + case 3: + rope_rope_norm<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); + break; + default: GGML_ABORT("fatal error"); + } + } else { + switch (kernel) { + case 0: + rope_rope_neox<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); + break; + case 1: + rope_rope_multi<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_mrope); + break; + case 2: + rope_rope_vision<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, nr, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + break; + case 3: + rope_rope_norm<<>>( + x1, x2, dst1, dst2, ne0, ne1_1, ne1_2, s1_1, s2_1, se1_2, se2_2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); + break; + default: GGML_ABORT("fatal error"); + } + } +} + +template +bool ggml_cuda_op_rope_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) { + if (dst1->src[1] != dst2->src[1]) return false; + if (dst1->src[2] != dst2->src[2]) return false; + //if (dst1->ne[2] > 1 || dst1->ne[3] > 1 || dst1->ne[0] != dst2->ne[0] || dst1->ne[2] != dst2->ne[2] || dst1->ne[3] != dst2->ne[3]) return false; + if (dst1->ne[3] > 1 || dst1->ne[0] != dst2->ne[0] || dst1->ne[2] != dst2->ne[2] || dst1->ne[3] != dst2->ne[3]) return false; + if (memcmp(dst1->op_params, dst2->op_params, 15*sizeof(int))) return false; + if (dst1->src[0]->type != dst2->src[0]->type) return false; + if (dst1->type != dst2->type) return false; + if (dst1->src[0]->type != GGML_TYPE_F32 && dst1->src[0]->type != GGML_TYPE_F16) return false; + if (dst1->src[0]->type != dst1->type) return false; + + const int64_t ne00 = dst1->src[0]->ne[0]; + const int64_t ne01_1 = dst1->src[0]->ne[1]; + const int64_t ne01_2 = dst2->src[0]->ne[1]; + const int64_t ne02 = dst1->src[0]->ne[2]; + + const size_t s01_1 = dst1->src[0]->nb[1] / ggml_type_size(dst1->src[0]->type); + const size_t s02_1 = dst1->src[0]->nb[2] / ggml_type_size(dst1->src[0]->type); + const size_t s01_2 = dst2->src[0]->nb[1] / ggml_type_size(dst2->src[0]->type); + const size_t s02_2 = dst2->src[0]->nb[2] / ggml_type_size(dst2->src[0]->type); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst1->op_params)[1]; + const int mode = ((int32_t *) dst1->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst1->op_params)[4]; + mrope_sections sections; + + // RoPE alteration for extended context + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst1->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst1->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst1->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst1->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst1->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst1->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst1->op_params + 11, sizeof(int)*4); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + + const int32_t * pos = (const int32_t *) dst1->src[1]->data; + + const float * freq_factors = nullptr; + if (dst1->src[2] != nullptr) { + freq_factors = (const float *) dst1->src[2]->data; + } + + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); + + auto type = dst1->type; + int kernel = is_neox ? 0 : is_mrope && !is_vision ? 1 : is_vision ? 2 : 3; + + // compute + if (type == GGML_TYPE_F32) { + rope_rope_cuda(kernel, + (const float *)dst1->src[0]->data, (const float *)dst2->src[0]->data, (float *)dst1->data, (float *)dst2->data, + ne00, ne01_1, ne01_2, s01_1, s02_1, s01_2, s02_2, n_dims, ne02, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, ctx.stream()); + } else { + rope_rope_cuda(kernel, + (const half *)dst1->src[0]->data, (const half *)dst2->src[0]->data, (half *)dst1->data, (half *)dst2->data, + ne00, ne01_1, ne01_2, s01_1, s02_1, s01_2, s02_2, n_dims, ne02, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, ctx.stream()); + } + return true; +} + +bool ggml_cuda_op_rope_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) { + return ggml_cuda_op_rope_rope_impl(ctx, dst1, dst2); +} diff --git a/ggml/src/ggml-cuda/rope.cuh b/ggml/src/ggml-cuda/rope.cuh index 4bf0ed42..b4d036c8 100644 --- a/ggml/src/ggml-cuda/rope.cuh +++ b/ggml/src/ggml-cuda/rope.cuh @@ -13,3 +13,5 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2); bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2); + +bool ggml_cuda_op_rope_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);