Fused rope+rope (norm)

Iwan Kawrakow
2025-11-02 07:21:17 +02:00
parent f5ac78de5c
commit 623d775929
3 changed files with 99 additions and 47 deletions

View File

@@ -3322,22 +3322,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
 if (ENABLE_FUSION && i + 3 < cgraph->n_nodes &&
     (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
     (cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
-    cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST) {
-    //printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+3]->name);
-    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3]);
-    i += 2;
+    cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
+    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3])) {
+    i += 3;
 }
 else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
     (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
-    cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST) {
-    //printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+2]->name);
-    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2]);
+    cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST &&
+    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2])) {
     i += 2;
 }
 else if (ENABLE_FUSION && i + 1 < cgraph->n_nodes &&
-    cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST) {
-    //printf("Fusing %s, %s\n", dst->name, cgraph->nodes[i+1]->name);
-    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1]);
+    cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
+    ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1])) {
     i += 1;
 }
 else {
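Note the contract change this hunk relies on: ggml_cuda_op_fused_rope_fast now returns bool, so the fusion attempt is part of the branch condition, and a refusal (unsupported type, shape, or rope mode) falls through to the unfused path with nothing consumed. The first branch also corrects the skip count from i += 2 to i += 3. A self-contained sketch of this try-fuse-or-fall-back shape; Node, try_fuse and eval_node are hypothetical stand-ins for the real ggml types and helpers:

    #include <vector>

    struct Node { int op = 0; };

    // Stubs for illustration only: the real helper launches the fused kernel
    // and reports whether fusion was possible; the real fallback dispatches
    // a single op.
    static bool try_fuse(Node &, Node &) { return false; }
    static void eval_node(Node &) {}

    static void eval_graph(std::vector<Node> & nodes) {
        for (size_t i = 0; i < nodes.size(); ++i) {
            // The fusing call sits last in the condition: on success it has
            // already done the work, so we skip the glue nodes it consumed;
            // on failure nothing was skipped and the node runs normally.
            if (i + 3 < nodes.size() && try_fuse(nodes[i], nodes[i + 3])) {
                i += 3;
                continue;
            }
            eval_node(nodes[i]);
        }
    }

    int main() {
        std::vector<Node> g(8);
        eval_graph(g);
    }

Putting the fusing call at the end of the condition keeps the attempt cheap: the op checks short-circuit before any launch is tried.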

View File

@@ -233,6 +233,55 @@ static __global__ void rope_norm_fast(const float * src0, const float * src1, fl
     dst[idst + 1] = x0*sin_theta + x1*cos_theta;
 }

+static __global__ void fused_rope_norm_fast(const float * src0_1, const float * src0_2, const float * src1,
+        float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int nelem1, int nelem,
+        int s01_1, int s02_1, int s01_2, int s02_2, int n_dims) {
+    int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (i >= nelem) {
+        return;
+    }
+
+    const float * src0;
+    float * dst;
+    int ne1, s01, s02;
+    if (i < nelem1) {
+        src0 = src0_1;
+        dst  = dst_1;
+        ne1  = ne1_1;
+        s01  = s01_1;
+        s02  = s02_1;
+    } else {
+        i -= nelem1;
+        src0 = src0_2;
+        dst  = dst_2;
+        ne1  = ne1_2;
+        s01  = s01_2;
+        s02  = s02_2;
+    }
+
+    int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
+    int i1 = i / ne0;
+    int i0 = i - i1*ne0;
+
+    const int idst = i2*ne0*ne1 + i1*ne0 + i0;
+    const int ix   = i2*s02 + i1*s01 + i0;
+
+    if (i0 >= n_dims) {
+        dst[idst + 0] = src0[ix + 0];
+        dst[idst + 1] = src0[ix + 1];
+        return;
+    }
+
+    const float x0 = src0[ix + 0];
+    const float x1 = src0[ix + 1];
+
+    const float cos_theta = src1[i2*ne0 + i0 + 0];
+    const float sin_theta = src1[i2*ne0 + i0 + 1];
+
+    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+}
+
 template<bool forward, bool has_ff, typename T>
 static __global__ void rope_multi(
     const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
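The new kernel is the norm-mode counterpart of the neox one: it rotates adjacent element pairs (i0, i0+1) for two tensors in one launch by flattening both into a single index range. Indices below nelem1 address the first tensor; anything above is re-based onto the second by subtracting nelem1. The re-based i then splits into (i2, i1, i0) = (token, head, position) against the contiguous destination, while ix applies the source strides s01/s02 so the fused reshape/view inputs read correctly; the cos/sin cache src1 is indexed by (token, position) only and shared across heads. A host-side check of the same index arithmetic, with made-up toy sizes (only the arithmetic is taken from the kernel):

    #include <cassert>

    int main() {
        // Toy sizes: head dim ne0 (must be even), two head counts, ne2 tokens.
        const int ne0 = 8, ne1_1 = 3, ne1_2 = 2, ne2 = 4;
        const int nelem1 = ne0*ne1_1*ne2;           // 96 elements in tensor 1
        const int nelem  = nelem1 + ne0*ne1_2*ne2;  // 96 + 64 = 160 in total

        for (int g = 0; g < nelem; g += 2) {        // each thread owns pair (g, g+1)
            int i   = g < nelem1 ? g : g - nelem1;  // re-base into the second tensor
            int ne1 = g < nelem1 ? ne1_1 : ne1_2;
            int i2 = i / (ne0*ne1); i -= i2*ne0*ne1; // token
            int i1 = i / ne0;                        // head
            int i0 = i - i1*ne0;                     // position within the head
            // Because ne0 is even and threads step by two, i0 is always even,
            // so a pair never straddles a row: this is what the launcher's
            // GGML_ASSERT(ne0 % 2 == 0) guarantees.
            assert(i2 < ne2 && i1 < ne1 && i0 + 1 < ne0 && i0 % 2 == 0);
        }
    }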
@@ -407,6 +456,20 @@ static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2
         s01_1, s02_1, s01_2, s02_2, n_dims);
 }

+static void fused_rope_norm_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
+        float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2,
+        int n_dims, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int nelem1 = ne0*ne1_1*ne2;
+    const int nelem2 = ne0*ne1_2*ne2;
+    const int nelem  = nelem1 + nelem2;
+    const int n_blocks = (nelem + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(n_blocks, 1, 1);
+    fused_rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0_1, src0_2, src1, dst_1, dst_2, ne0, ne1_1, ne1_2, nelem1, nelem,
+        s01_1, s02_1, s01_2, s02_2, n_dims);
+}
+
 static void rope_norm_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
         int n_dims, cudaStream_t stream) {
     GGML_ASSERT(ne00 % 2 == 0);
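Each thread handles one two-float pair, hence the ceil-divide by 2*CUDA_ROPE_BLOCK_SIZE when sizing the grid over the combined element count. A quick check of the arithmetic, assuming a block size of 256 (the real CUDA_ROPE_BLOCK_SIZE is defined elsewhere in this file and may differ):

    #include <cstdio>

    int main() {
        const int block_size = 256;  // assumed stand-in for CUDA_ROPE_BLOCK_SIZE
        const int nelem      = 160;  // toy total from the example above
        // Same ceil-divide as fused_rope_norm_fast_cuda: one thread per pair.
        const int n_blocks = (nelem + 2*block_size - 1) / (2*block_size);
        std::printf("%d block(s) x %d threads cover %d pairs\n",
                    n_blocks, block_size, nelem/2);
    }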
@@ -754,26 +817,35 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     }
 }

-void ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
-    GGML_ASSERT(dst1->src[1] == dst2->src[1]);
+bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) {
+    if (dst1->src[1] != dst2->src[1]) return false;

     const ggml_tensor * src0_1 = dst1->src[0];
     const ggml_tensor * src0_2 = dst2->src[0];
     const ggml_tensor * src1   = dst1->src[1];

-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0_1->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0_2->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst2->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == dst1->type);
-
-    GGML_ASSERT(src0_1->ne[0] == src0_2->ne[0]);
-    GGML_ASSERT(src0_1->ne[2] == src0_2->ne[2]);
-    GGML_ASSERT(src0_1->ne[3] == src0_2->ne[3]);
-
-    const int64_t ne00 = src0_1->ne[0]; // head dims
-    const int64_t ne02 = src0_1->ne[2]; // num heads
+    if (src0_1->type != GGML_TYPE_F32) return false;
+    if (src0_2->type != GGML_TYPE_F32) return false;
+    if (dst1->type   != GGML_TYPE_F32) return false;
+    if (dst2->type   != GGML_TYPE_F32) return false;
+    if (src1->type   != dst1->type)    return false;
+
+    if (src0_1->ne[0] != src0_2->ne[0]) return false;
+    if (src0_1->ne[2] != src0_2->ne[2]) return false;
+    if (src0_1->ne[3] != src0_2->ne[3]) return false;
+
+    const int n_dims = ((const int32_t *) src1->op_params)[1];
+    const int mode   = ((const int32_t *) src1->op_params)[2];
+
+    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_vision || is_mrope) return false; // not implemented
+
+    const int64_t ne00 = src0_1->ne[0]; // head dims
+    const int64_t ne02 = src0_1->ne[2]; // num tokens
     const int64_t ne01_1 = src0_1->ne[1]; // num heads
     const int64_t ne01_2 = src0_2->ne[1]; // num heads
@@ -782,36 +854,19 @@ void ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor *
     const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type);
     const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type);

-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((const int32_t *) src1->op_params)[1];
-    const int mode   = ((const int32_t *) src1->op_params)[2];

-    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

-    if (is_vision) {
-        GGML_ASSERT(n_dims == ne00/2);
-    }

     // compute
     if (is_neox) {
-        //printf("Using neox\n");
         fused_rope_neox_fast_cuda(
             (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
-            (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
-    //} else if (is_mrope && !is_vision) {
-    //    rope_multi_fast_cuda(
-    //        (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
-    //        (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
-    //} else if (is_vision) {
-    //    rope_vision_fast_cuda(
-    //        (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
-    //        (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
-    //} else {
-    //    //printf("Using norm\n");
-    //    fused_rope_norm_fast_cuda(
-    //        (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
-    //        (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, stream);
+            (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, ctx.stream());
+    } else {
+        fused_rope_norm_fast_cuda(
+            (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data,
+            (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, ctx.stream());
     }
+
+    return true;
 }
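With the asserts turned into early returns, the function is now a combined feasibility check and launch: every rejection (the two ROPEs not sharing the same src1 cache, non-F32 data, mismatched shapes, mrope/vision modes) fires before any kernel launch, so a false return has no side effects and the graph loop in the first file can evaluate the same nodes unfused. A condensed caller-side sketch of that contract; in this commit the real fallback is simply not advancing i in the graph loop, not an explicit pair of calls like this:

    // Sketch only: if the fused path declines, both ROPEs still run
    // individually. A false return has launched nothing and written nothing.
    if (!ggml_cuda_op_fused_rope_fast(ctx, dst1, dst2)) {
        ggml_cuda_op_rope_fast(ctx, dst1);
        ggml_cuda_op_rope_fast(ctx, dst2);
    }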

View File

@@ -10,4 +10,4 @@ void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-void ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);
+bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);