diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 6b114ef8..c78a0f39 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3245,7 +3245,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
         case GGML_OP_FUSED_RMS_NORM:
-            if (i + 2 < cgraph->n_nodes &&
+            // NOTE(review): the rms+rms+rope+rope fusion is intentionally kept
+            // disabled ("false &&") until the fused kernel is validated end-to-end.
+            if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
+                cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
+                cgraph->nodes[i+4]->op == GGML_OP_ROPE_FAST &&
+                ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
+                i += 4;
+            }
+            else if (ENABLE_FUSION && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) {
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index 00b74287..064aa98c 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -202,6 +202,89 @@ static __global__ void fused_rope_neox_fast(const float * src0_1, const float *
     dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
+// Fused RMS norm + NEoX RoPE for two inputs (e.g. Q and K heads) sharing one rope cache.
+// Expected launch: block = (1, CUDA_ROPE_BLOCK_SIZE, 1), grid = (max(ne1_1, ne1_2), 1, ne2).
+// With blockDim.x == blockDim.z == 1, i1 and i2 are uniform across the block, so the
+// __syncthreads() below are reached by all threads even inside the i1-guarded branches.
+static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const float * src0_2, const float * src1,
+        const float * c_1, const float * c_2,
+        float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2,
+        int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) {
+
+    int i0 = 2*threadIdx.y;
+    int i1 = blockIdx.x*blockDim.x + threadIdx.x;
+    int i2 = blockIdx.z*blockDim.z + threadIdx.z;
+
+    __shared__ float s_sum[WARP_SIZE];
+
+    src0_1 += i1*s01_1 + i2*s02_1;
+    src0_2 += i1*s01_2 + i2*s02_2;
+    dst_1 += ne0*(i1 + i2*ne1_1);
+    dst_2 += ne0*(i1 + i2*ne1_2);
+
+    float norm_1 = 1, norm_2 = 1;
+    if (i1 < ne1_1) {
+        float sum = i0 < ne0 ? src0_1[i0]*src0_1[i0] + src0_1[i0+1]*src0_1[i0+1] : 0.0f;
+        sum = warp_reduce_sum(sum);
+        if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
+            int warp_id = (i0/2) / WARP_SIZE;
+            int lane_id = (i0/2) % WARP_SIZE;
+            if (lane_id == 0) s_sum[warp_id] = sum;
+            __syncthreads();
+            sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
+            sum = warp_reduce_sum(sum);
+        }
+        norm_1 = rsqrtf(sum/ne0 + eps);
+    }
+    if (i1 < ne1_2) {
+        float sum = i0 < ne0 ? src0_2[i0]*src0_2[i0] + src0_2[i0+1]*src0_2[i0+1] : 0.0f;
+        sum = warp_reduce_sum(sum);
+        if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) {
+            int warp_id = (i0/2) / WARP_SIZE;
+            int lane_id = (i0/2) % WARP_SIZE;
+            __syncthreads(); // all reads of s_sum from the first reduction must finish before overwrite
+            if (lane_id == 0) s_sum[warp_id] = sum;
+            __syncthreads();
+            sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0;
+            sum = warp_reduce_sum(sum);
+        }
+        norm_2 = rsqrtf(sum/ne0 + eps);
+    }
+
+    if (i0 >= ne0) return;
+
+    if (i0 >= n_dims) {
+        // Past the rotated dimensions: plain rms_norm * scale pass-through.
+        if (i1 < ne1_1) {
+            dst_1[i0 + 0] = norm_1*c_1[i0 + 0]*src0_1[i0 + 0];
+            dst_1[i0 + 1] = norm_1*c_1[i0 + 1]*src0_1[i0 + 1];
+        }
+        if (i1 < ne1_2) {
+            dst_2[i0 + 0] = norm_2*c_2[i0 + 0]*src0_2[i0 + 0];
+            dst_2[i0 + 1] = norm_2*c_2[i0 + 1]*src0_2[i0 + 1];
+        }
+        return;
+    }
+
+    const float cos_theta = src1[i2*ne0 + i0 + 0];
+    const float sin_theta = src1[i2*ne0 + i0 + 1];
+
+    if (i1 < ne1_1) {
+        const float x0 = norm_1*c_1[i0/2 + 0]*src0_1[i0/2 + 0];
+        const float x1 = norm_1*c_1[i0/2 + n_dims/2]*src0_1[i0/2 + n_dims/2];
+        dst_1[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
+        dst_1[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    }
+
+    if (i1 < ne1_2) {
+        const float x0 = norm_2*c_2[i0/2 + 0]*src0_2[i0/2 + 0];
+        const float x1 = norm_2*c_2[i0/2 + n_dims/2]*src0_2[i0/2 + n_dims/2];
+        dst_2[i0/2 + 0] = x0*cos_theta - x1*sin_theta;
+        dst_2[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    }
+
+}
+
 static
__global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem, int s01, int s02, int n_dims) { int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); @@ -456,6 +533,19 @@ static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2 s01_1, s02_1, s01_2, s02_2, n_dims); } +static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1, + const float * c_1, const float * c_2, + float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2, + int n_dims, float eps, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + int ne1 = std::max(ne1_1, ne1_2); + const dim3 block_nums(ne1, 1, ne2); + fused_rms_rope_neox_fast<<>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2, + s01_1, s02_1, s01_2, s02_2, n_dims, eps); +} + static void fused_rope_norm_fast_cuda(const float * src0_1, const float * src0_2, const float * src1, float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int ne2, int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, cudaStream_t stream) { @@ -870,3 +960,71 @@ bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * } return true; } + +bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2) { + + if (dst1->src[1] != dst2->src[1]) return false; + + const auto rms_1 = dst1->src[0]; + const auto rms_2 = dst2->src[0]; + const auto src1 = dst1->src[1]; + + if (rms_1->op != GGML_OP_FUSED_RMS_NORM) return false; + if (rms_2->op != GGML_OP_FUSED_RMS_NORM) return false; + + const auto src0_1 = rms_1->src[0]; + const auto src0_2 = rms_2->src[0]; + + if (src0_1->type != GGML_TYPE_F32) return false; + if (src0_2->type != GGML_TYPE_F32) return false; + if (dst1->type != GGML_TYPE_F32) return false; + if (dst2->type 
!= GGML_TYPE_F32) return false; + if (src1->type != dst1->type) return false; + + if (src0_1->ne[0] != src0_2->ne[0]) return false; + if (src0_1->ne[2] != src0_2->ne[2]) return false; + if (src0_1->ne[3] != src0_2->ne[3]) return false; + if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false; + + GGML_ASSERT(ggml_nrows(rms_1->src[1]) == 1); + GGML_ASSERT(ggml_nrows(rms_2->src[1]) == 1); + GGML_ASSERT(rms_1->src[1]->ne[0] == src0_1->ne[0]); + GGML_ASSERT(rms_2->src[1]->ne[0] == src0_2->ne[0]); + + const int n_dims = ((const int32_t *) src1->op_params)[1]; + const int mode = ((const int32_t *) src1->op_params)[2]; + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_vision || is_mrope) return false; // not implemented + if (!is_neox) return false; // TODO + + float eps1, eps2; + memcpy(&eps1, rms_1->op_params, sizeof(float)); + memcpy(&eps2, rms_2->op_params, sizeof(float)); + if (eps1 != eps2) return false; + + const int64_t ne00 = src0_1->ne[0]; // head dims + const int64_t ne02 = src0_1->ne[2]; // num tokens + const int64_t ne01_1 = src0_1->ne[1]; // num heads + const int64_t ne01_2 = src0_2->ne[1]; // num heads + + const size_t s01_1 = src0_1->nb[1] / ggml_type_size(src0_1->type); + const size_t s02_1 = src0_1->nb[2] / ggml_type_size(src0_1->type); + const size_t s01_2 = src0_2->nb[1] / ggml_type_size(src0_2->type); + const size_t s02_2 = src0_2->nb[2] / ggml_type_size(src0_2->type); + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + + // compute + fused_rms_rope_neox_fast_cuda( + (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data, + (const float *)rms_1->src[1]->data, (const float *)rms_2->src[1]->data, + (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1, + ctx.stream()); + return true; +} diff --git a/ggml/src/ggml-cuda/rope.cuh 
b/ggml/src/ggml-cuda/rope.cuh index f537473f..4bf0ed42 100644 --- a/ggml/src/ggml-cuda/rope.cuh +++ b/ggml/src/ggml-cuda/rope.cuh @@ -11,3 +11,5 @@ void ggml_cuda_op_rope_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_op_fused_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2); + +bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst1, ggml_tensor * dst2);