diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index c0c9d405..ff4b4e58 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -868,9 +868,8 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((const int32_t *) src1->op_params)[1]; - const int mode = ((const int32_t *) src1->op_params)[2]; + const int n_dims = ((const int32_t *) dst->op_params)[0]; + const int mode = ((const int32_t *) dst->op_params)[1]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; @@ -916,8 +915,10 @@ bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * if (src0_1->ne[2] != src0_2->ne[2]) return false; if (src0_1->ne[3] != src0_2->ne[3]) return false; - const int n_dims = ((const int32_t *) src1->op_params)[1]; - const int mode = ((const int32_t *) src1->op_params)[2]; + const int n_dims = ((const int32_t *) dst1->op_params)[0]; + const int mode = ((const int32_t *) dst1->op_params)[1]; + + if (n_dims != dst2->op_params[0] || mode != dst2->op_params[1]) return false; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; @@ -986,8 +987,10 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(c_1->ne[0] == src0_1->ne[0]); GGML_ASSERT(c_2->ne[0] == src0_2->ne[0]); - const int n_dims = ((const int32_t *) src1->op_params)[1]; - const int mode = ((const int32_t *) src1->op_params)[2]; + const int n_dims = ((const int32_t *) dst1->op_params)[0]; + const int mode = ((const int32_t *) dst1->op_params)[1]; + + if (n_dims != dst2->op_params[0] || mode != dst2->op_params[1]) return false; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ba15c70c..41501453 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -8735,6 +8735,9 @@ struct ggml_tensor * ggml_rope_fast( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + result->op_params[0] = b->op_params[1]; + result->op_params[1] = b->op_params[2]; + result->op = GGML_OP_ROPE_FAST; result->src[0] = a; result->src[1] = b; @@ -18586,8 +18589,8 @@ static void ggml_compute_forward_rope_fast_f32( GGML_ASSERT(src0->ne[0] <= src1->ne[0]); GGML_ASSERT(src0->ne[2] <= src1->ne[1]); - const int n_dims = ((const int32_t *) src1->op_params)[1]; - const int mode = ((const int32_t *) src1->op_params)[2]; + const int n_dims = dst->op_params[0]; + const int mode = dst->op_params[1]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding