Fused rope+rope

This commit is contained in:
Iwan Kawrakow
2025-11-02 06:57:08 +02:00
parent ea97dc3a1c
commit f5ac78de5c
4 changed files with 186 additions and 21 deletions

View File

@@ -3319,7 +3319,30 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ROPE_BACK:
ggml_cuda_op_rope_back(ctx, dst);
case GGML_OP_ROPE_FAST:
ggml_cuda_op_rope_fast(ctx, dst);
if (ENABLE_FUSION && i + 3 < cgraph->n_nodes &&
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
(cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST) {
//printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+3]->name);
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3]);
i += 2;
}
else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
(cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST) {
//printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+2]->name);
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2]);
i += 2;
}
else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST) {
//printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+1]->name);
ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1]);
i += 1;
}
else {
ggml_cuda_op_rope_fast(ctx, dst);
}
break;
case GGML_OP_ROPE_CACHE:
ggml_cuda_op_rope_cache(ctx, dst);

View File

@@ -121,11 +121,11 @@ static __global__ void rope_neox(
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_neox_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int ne2,
static __global__ void rope_neox_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
int s01, int s02, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= ne0*ne1*ne2) {
if (i >= nelem) {
return;
}
//i = i0 + i1*ne0 + i2*ne0*ne1;
@@ -153,11 +153,60 @@ static __global__ void rope_neox_fast(const float * src0, const float * src1, fl
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int ne2,
static __global__ void fused_rope_neox_fast(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int nelem1, int nelem,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= nelem) {
return;
}
const float * src0;
float * dst;
int ne1, s01, s02;
if (i < nelem1) {
src0 = src0_1;
dst = dst_1;
ne1 = ne1_1;
s01 = s01_1;
s02 = s02_1;
} else {
i -= nelem1;
src0 = src0_2;
dst = dst_2;
ne1 = ne1_2;
s01 = s01_2;
s02 = s02_2;
}
int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
int i1 = i / ne0;
int i0 = i - i1*ne0;
const int idst = i2*ne0*ne1 + i1*ne0 + i0/2;
const int ix = i2*s02 + i1*s01 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = src0[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = src0[ix + i0/2 + 1];
return;
}
const float x0 = src0[ix + 0];
const float x1 = src0[ix + n_dims/2];
const float cos_theta = src1[i2*ne0 + i0 + 0];
const float sin_theta = src1[i2*ne0 + i0 + 1];
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
int s01, int s02, int n_dims) {
int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= ne0*ne1*ne2) {
if (i >= nelem) {
return;
}
@@ -341,7 +390,21 @@ static void rope_neox_fast_cuda(const float * src0, const float * src1, float *
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
}
static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
int n_dims, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int nelem1 = ne0*ne1_1*ne2;
const int nelem2 = ne0*ne1_2*ne2;
const int nelem = nelem1 + nelem2;
const int n_blocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
fused_rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, dst_1, dst_2, ne0, ne1_1, ne1_2, nelem1, nelem,
s01_1, s02_1, s01_2, s02_2, n_dims);
}
static void rope_norm_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
@@ -350,7 +413,7 @@ static void rope_norm_fast_cuda(const float * src0, const float * src1, float *
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne02, s01, s02, n_dims);
rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
}
static void rope_multi_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
@@ -690,3 +753,65 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
}
}
void ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
GGML_ASSERT(dst1->src[1] == dst2->src[1]);
const ggml_tensor * src0_1 = dst1->src[0];
const ggml_tensor * src0_2 = dst2->src[0];
const ggml_tensor * src1 = dst1->src[1];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0_1->type == GGML_TYPE_F32);
GGML_ASSERT(src0_2->type == GGML_TYPE_F32);
GGML_ASSERT(dst1->type == GGML_TYPE_F32);
GGML_ASSERT(dst2->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == dst1->type);
GGML_ASSERT(src0_1->ne[0] == src0_2->ne[0]);
GGML_ASSERT(src0_1->ne[2] == src0_2->ne[2]);
GGML_ASSERT(src0_1->ne[3] == src0_2->ne[3]);
const int64_t ne00 = src0_1->ne[0]; // head dims
const int64_t ne02 = src0_1->ne[2]; // num heads
const int64_t ne01_1 = src0_1->ne[1]; // num heads
const int64_t ne01_2 = src0_2->ne[1]; // num heads
const size_t s01_1 = src0_1->nb[1] / ggml_type_size(src0_1->type);
const size_t s02_1 = src0_1->nb[2] / ggml_type_size(src0_1->type);
const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type);
const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((const int32_t *) src1->op_params)[1];
const int mode = ((const int32_t *) src1->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
// compute
if (is_neox) {
//printf("Using neox\n");
fused_rope_neox_fast_cuda(
(const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
(float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
//} else if (is_mrope && !is_vision) {
// rope_multi_fast_cuda(
// (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
// (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
//} else if (is_vision) {
// rope_vision_fast_cuda(
// (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
// (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
//} else {
// //printf("Using norm\n");
// fused_rope_norm_fast_cuda(
// (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
// (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
}
}

View File

@@ -9,3 +9,5 @@ void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);

View File

@@ -7779,6 +7779,9 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() {
ggml_cgraph * llm_build_context::build_openai_moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
@@ -7796,6 +7799,9 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
const int sliding_window_pattern = 2;
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
for (int il = 0; il < n_layer; ++il) {
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
ggml_tensor * inpSA = inpL;
@@ -7815,14 +7821,17 @@ ggml_cgraph * llm_build_context::build_openai_moe() {
model.layers[il].wv, model.layers[il].bv,
nullptr, nullptr, 0.0f, il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
beta_fast, beta_slow);
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
// beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
attn_factor, beta_fast, beta_slow);
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
// attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
//auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, model.layers[il].wq, model.layers[il].bq,
@@ -7926,6 +7935,9 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
auto rope_cache = ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
@@ -7950,16 +7962,19 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() {
//Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
//cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
//Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
//cb(Kcur, "Kcur_normed", il);
//Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
// ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
////Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
////cb(Kcur, "Kcur_normed", il);
//Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
// n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
// ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);