diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 6ffe908c..3faf60b5 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3249,7 +3249,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg // for (int j = 1; j <= 6; ++j) printf(" %s(%s)\n", ggml_op_name(cgraph->nodes[i+j]->op), cgraph->nodes[i+j]->name); //} // Disabled because currently there is something wrong in the fused kernel implementation - if (false && ENABLE_FUSION && i + 4 < cgraph->n_nodes && + if (ENABLE_FUSION && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST && diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 064aa98c..b535854e 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -208,74 +208,56 @@ static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const floa int s01_1, int s02_1, int s01_2, int s02_2, int n_dims, float eps) { int i0 = 2*threadIdx.y; - int i1 = blockIdx.x*blockDim.x + threadIdx.x; - int i2 = blockIdx.z*blockDim.z + threadIdx.z; + int i2 = blockIdx.x*blockDim.x + threadIdx.x; + int i1 = blockIdx.z*blockDim.z + threadIdx.z; - __shared__ float s_sum[WARP_SIZE]; + const float * src0, *c; + float * dst; + int ne1, s01, s02; - src0_1 += i1*s01_1 + i2*s02_1; - src0_2 += i1*s01_2 + i2*s02_2; - dst_1 += ne0*(i1 + i2*ne1_1); - dst_2 += ne0*(i1 + i2*ne1_2); - - float norm_1 = 1, norm_2 = 1; if (i1 < ne1_1) { - float sum = i0 < ne0 ? src0_1[i0]*src0_1[i0] + src0_1[i0+1]*src0_1[i0+1] : 0.0f; - sum = warp_reduce_sum(sum); - if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) { - int warp_id = (i0/2) / WARP_SIZE; - int lane_id = (i0/2) % WARP_SIZE; - if (lane_id == 0) s_sum[warp_id] = sum; - __syncthreads(); - sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0; - sum = warp_reduce_sum(sum); - } - norm_1 = rsqrtf(sum/ne0 + eps); + ne1 = ne1_1; + s01 = s01_1; s02 = s02_1; + src0 = src0_1 + i1*s01 + i2*s02; + dst = dst_1 + ne0*(i1 + i2*ne1); + c = c_1; + } else { + i1 -= ne1_1; + ne1 = ne1_2; + s01 = s01_2; s02 = s02_2; + src0 = src0_2 + i1*s01 + i2*s02; + dst = dst_2 + ne0*(i1 + i2*ne1); + c = c_2; } - if (i2 < ne1_2) { - float sum = i0 < ne0 ? src0_2[i0]*src0_2[i0] + src0_2[i0+1]*src0_2[i0+1] : 0.0f; + + float sum = i0 < ne0 ? src0[i0]*src0[i0] + src0[i0+1]*src0[i0+1] : 0.0f; + sum = warp_reduce_sum(sum); + if constexpr (CUDA_ROPE_BLOCK_SIZE > WARP_SIZE) { + __shared__ float s_sum[WARP_SIZE]; + int warp_id = (i0/2) / WARP_SIZE; + int lane_id = (i0/2) % WARP_SIZE; + if (lane_id == 0) s_sum[warp_id] = sum; + __syncthreads(); + sum = lane_id < CUDA_ROPE_BLOCK_SIZE / WARP_SIZE ? s_sum[lane_id] : 0; sum = warp_reduce_sum(sum); - if constexpr (CUDA_ROPE_BLOCK_SIZE > 2*WARP_SIZE) { - int warp_id = (i0/2) / WARP_SIZE; - int lane_id = (i0/2) % WARP_SIZE; - if (lane_id == 0) s_sum[warp_id] = sum; - __syncthreads(); - sum = lane_id < CUDA_ROPE_BLOCK_SIZE / (2*WARP_SIZE) ? s_sum[lane_id] : 0; - sum = warp_reduce_sum(sum); - } - norm_2 = rsqrtf(sum/ne0 + eps); } + float norm = rsqrtf(sum/ne0 + eps); if (i0 >= ne0) return; if (i0 >= n_dims) { - if (i1 < ne1_1) { - dst_1[i0 + 0] = norm_1*c_1[i0 + 0]*src0_1[i0 + 0]; - dst_1[i0 + 1] = norm_1*c_1[i0 + 1]*src0_1[i0 + 1]; - } - if (i1 < ne1_2) { - dst_2[i0 + 0] = norm_2*c_2[i0 + 0]*src0_2[i0 + 0]; - dst_2[i0 + 1] = norm_2*c_2[i0 + 1]*src0_2[i0 + 1]; - } + dst[i0 + 0] = norm*c[i0 + 0]*src0[i0 + 0]; + dst[i0 + 1] = norm*c[i0 + 1]*src0[i0 + 1]; return; } const float cos_theta = src1[i2*ne0 + i0 + 0]; const float sin_theta = src1[i2*ne0 + i0 + 1]; - if (i1 < ne1_1) { - const float x0 = norm_1*c_1[i0/2 + 0]*src0_1[i0/2 + 0]; - const float x1 = norm_1*c_1[i0/2 + n_dims/2]*src0_1[i0/2 + n_dims/2]; - dst_1[i0/2 + 0] = x0*cos_theta - x1*sin_theta; - dst_1[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta; - } - - if (i1 < ne1_2) { - const float x0 = norm_2*c_2[i0/2 + 0]*src0_2[i0/2 + 0]; - const float x1 = norm_2*c_2[i0/2 + n_dims/2]*src0_2[i0/2 + n_dims/2]; - dst_2[i0/2 + 0] = x0*cos_theta - x1*sin_theta; - dst_2[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta; - } + const float x0 = norm*c[i0/2 + 0]*src0[i0/2 + 0]; + const float x1 = norm*c[i0/2 + n_dims/2]*src0[i0/2 + n_dims/2]; + dst[i0/2 + 0] = x0*cos_theta - x1*sin_theta; + dst[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta; } @@ -540,8 +522,7 @@ static void fused_rms_rope_neox_fast_cuda(const float * src0_1, const float * sr GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 <= 2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - int ne1 = std::max(ne1_1, ne1_2); - const dim3 block_nums(ne1, 1, ne2); + const dim3 block_nums(ne2, 1, ne1_1 + ne1_2); fused_rms_rope_neox_fast<<>>(src0_1, src0_2, src1, c_1, c_2, dst_1, dst_2, ne0, ne1_1, ne1_2, s01_1, s02_1, s01_2, s02_2, n_dims, eps); } @@ -974,22 +955,26 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens const auto src0_1 = rms_1->src[0]; const auto src0_2 = rms_2->src[0]; + const auto c_1 = rms_1->src[1]; + const auto c_2 = rms_2->src[1]; if (src0_1->type != GGML_TYPE_F32) return false; if (src0_2->type != GGML_TYPE_F32) return false; if (dst1->type != GGML_TYPE_F32) return false; if (dst2->type != GGML_TYPE_F32) return false; if (src1->type != dst1->type) return false; + if (c_1->type != GGML_TYPE_F32) return false; + if (c_2->type != GGML_TYPE_F32) return false; if (src0_1->ne[0] != src0_2->ne[0]) return false; if (src0_1->ne[2] != src0_2->ne[2]) return false; if (src0_1->ne[3] != src0_2->ne[3]) return false; if (src0_1->ne[0] > 2*CUDA_ROPE_BLOCK_SIZE) return false; - GGML_ASSERT(ggml_nrows(rms_1->src[1]) == 1); - GGML_ASSERT(ggml_nrows(rms_2->src[1]) == 1); - GGML_ASSERT(rms_1->src[1]->ne[0] == src0_1->ne[0]); - GGML_ASSERT(rms_2->src[1]->ne[0] == src0_2->ne[0]); + GGML_ASSERT(ggml_nrows(c_1) == 1); + GGML_ASSERT(ggml_nrows(c_2) == 1); + GGML_ASSERT(c_1->ne[0] == src0_1->ne[0]); + GGML_ASSERT(c_2->ne[0] == src0_2->ne[0]); const int n_dims = ((const int32_t *) src1->op_params)[1]; const int mode = ((const int32_t *) src1->op_params)[2]; @@ -1023,7 +1008,7 @@ bool ggml_cuda_op_fused_rms_rope_fast(ggml_backend_cuda_context & ctx, ggml_tens // compute fused_rms_rope_neox_fast_cuda( (const float *)src0_1->data, (const float *)src0_2->data, (const float *)src1->data, - (const float *)rms_1->src[1]->data, (const float *)rms_2->src[1]->data, + (const float *)c_1->data, (const float *)c_2->data, (float *)dst1->data, (float *)dst2->data, ne00, ne01_1, ne01_2, ne02, s01_1, s02_1, s01_2, s02_2, n_dims, eps1, ctx.stream()); return true;