diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index dd6e8616..afcef90a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2712,22 +2712,24 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten CUDA_CHECK(cudaStreamSynchronize(stream)); ggml_tensor src0_1_row = *src0_1; - ggml_tensor src0_2_row = *src0_2; + ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2; ggml_tensor src1_row = *src1; ggml_tensor final_dst; ggml_tensor final_src; char * src0_1_original = (char *) src0_1->data; - char * src0_2_original = (char *) src0_2->data; + char * src0_2_original = src0_2 ? (char *) src0_2->data : nullptr; char * src1_original = (char *) src1->data; char * dst_original = (char *) dst->data; src0_1_row.ne[2] = 1; src0_1_row.ne[3] = 1; src0_1_row.nb[3] = nb02; - src0_2_row.ne[2] = 1; - src0_2_row.ne[3] = 1; - src0_2_row.nb[3] = nb02; + if (src0_2) { + src0_2_row.ne[2] = 1; + src0_2_row.ne[3] = 1; + src0_2_row.nb[3] = nb02; + } src1_row.ne[1] = 1; src1_row.ne[2] = 1; @@ -2755,7 +2757,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten ggml_cuda_pool_alloc src1_quantized(ctx.pool()); bool use_quantized_src1 = false; int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; - if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { + if (ggml_is_quantized(src0_1->type) && (!src0_2 || src0_1->type == src0_2->type) && src1->ne[1] == 1 && src1->ne[3] == 1) { if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); @@ -2768,8 +2770,14 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten if (!use_quantized_src1) { src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); } - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool()), dst_gate_contiguous(ctx.pool()); + if (src0_2) { + dst_up_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + } else { + dst_up_contiguous.alloc(2*sizeof(float)*ggml_nelements(dst)); + dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + } ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); if (fuse_down) { final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); @@ -2812,20 +2820,26 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten } src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; + if (src0_2_original) src0_2_row.data = src0_2_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + auto nb1l = nb1; + if (!src0_2) { + nb1l *= 2; + dst_row.ne[0] *= 2; + } + src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + dst_row.nb[1] = nb1l; + dst_row.nb[2] = num_src1_rows*nb1l; + dst_row.nb[3] = num_src1_rows*nb1l; dst_row.data = dst_up_contiguous.get(); if (use_quantized_src1) { @@ -2843,6 +2857,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten CUDA_CHECK(cudaGetLastError()); } + if (src0_2) { dst_row.data = dst_gate_contiguous.get(); if (use_quantized_src1) { ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); @@ -2858,8 +2873,10 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); } + } auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (src0_2) { if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], @@ -2869,6 +2886,17 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); } + } else { + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(dst), dst->ne[0], dst->ne[0], dst->ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } + } CUDA_CHECK(cudaGetLastError()); if (fuse_down) {