diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index abd2479b..8f97c890 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3244,7 +3244,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_rms_norm(ctx, dst); break; case GGML_OP_FUSED_RMS_NORM: - ggml_cuda_op_fused_rms_norm(ctx, dst); + if (i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_VIEW && + cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && + dst->ne[2] == 1 && cgraph->nodes[i+2]->ne[2] == 1) { + ggml_cuda_op_fused_rms_rms_norm(ctx, dst, cgraph->nodes[i+2]); + i += 2; + } else { + ggml_cuda_op_fused_rms_norm(ctx, dst); + } break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index f296b79f..98d33ebc 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -619,3 +619,84 @@ void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data, src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream); } + +template +static __global__ void fused_rms_rms_norm_f32(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps, + const char *x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + auto x_row = (const float *)(row < nrows1 ? x1 + row*nb1 : x2 + (row - nrows1)*nb2); + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x_row[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + auto dst = row < nrows1 ? y1 + row*ncols : y2 + (row - nrows1)*ncols; + auto c = row < nrows1 ? c1 : c2; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * c[col] * x_row[col]; + } +} + +static void fused_rms_rms_norm_f32_cuda(int ncols, int nrows1, int nrows2, size_t nb1, size_t nb2, float eps, + const char * x1, const char * x2, const float * c1, const float * c2, float * y1, float * y2, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + int nrows = nrows1 + nrows2; + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + fused_rms_rms_norm_f32<256><<>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2); + } else { + const dim3 block_dims(1024, 1, 1); + fused_rms_rms_norm_f32<1024><<>>(ncols, nrows1, nrows2, nb1, nb2, eps, x1, x2, c1, c2, y1, y2); + } +} + +void ggml_cuda_op_fused_rms_rms_norm([[maybe_unused]] ggml_backend_cuda_context & ctx, [[maybe_unused]] ggml_tensor * rms1, [[maybe_unused]] ggml_tensor * rms2) { + GGML_ASSERT(rms1->ne[2] == 1 && rms1->ne[3] == 1); + GGML_ASSERT(rms2->ne[2] == 1 && rms2->ne[3] == 1); + GGML_ASSERT(rms1->ne[0] == rms2->ne[0]); + GGML_ASSERT(rms1->type == GGML_TYPE_F32); + GGML_ASSERT(rms2->type == GGML_TYPE_F32); + GGML_ASSERT(rms1->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms2->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms1->src[0]->ne[0] == rms1->src[1]->ne[0]); + GGML_ASSERT(rms2->src[0]->ne[0] == rms2->src[1]->ne[0]); + GGML_ASSERT(ggml_nrows(rms1->src[1]) == 1); + GGML_ASSERT(ggml_nrows(rms2->src[1]) == 1); + GGML_ASSERT(rms1->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(rms2->src[1]->type == GGML_TYPE_F32); + + float eps1, eps2; + memcpy(&eps1, rms1->op_params, sizeof(float)); + memcpy(&eps2, rms2->op_params, sizeof(float)); + GGML_ASSERT(eps1 == eps2); + + fused_rms_rms_norm_f32_cuda(rms1->ne[0], rms1->ne[1], rms2->ne[1], rms1->nb[1], rms2->nb[1], eps1, + (const char *)rms1->src[0]->data, (const char *)rms2->src[0]->data, + (const float *)rms1->src[1]->data, (const float *)rms2->src[1]->data, + (float *)rms1->data, (float *)rms2->data, ctx.stream()); + + +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 40f758de..cadf7248 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -11,3 +11,5 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst); void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst); + +void ggml_cuda_op_fused_rms_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * rms1, ggml_tensor * rms2); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 04e5a142..24150b11 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1279,10 +1279,12 @@ std::tuple llm_build_context::llm_buil if (q_norm) { Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); cb(Qcur, "Qcur_normed", il); + ggml_build_forward_expand(gf, Qcur); } if (k_norm) { Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); cb(Kcur, "Kcur_normed", il); + ggml_build_forward_expand(gf, Kcur); } return {Qcur, Kcur, Vcur}; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index d2797911..d1265616 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -2451,7 +2451,6 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) { layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]); layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1] ); fused_qkv = true; - printf("================================== Created merged qkv %s\n", layer.wqkv->name); if (bias) { auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i); auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);