diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index a7ae405e..8a7cd303 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2344,11 +2344,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); - //printf("dst: %ld x %ld x %ld, %zu x %zu x %zu\n", dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2]); - //if (next && next->op == GGML_OP_MUL_MAT_ID) { - // printf(" next: %ld x %ld x %ld, %zu x %zu x %zu\n", next->ne[0], next->ne[1], next->ne[2], next->nb[0], next->nb[1], next->nb[2]); - //} - auto local_dst = *dst; local_dst.ne[2] = n_ids; local_dst.ne[1] = local_dst.ne[3] = 1; @@ -2409,32 +2404,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); - //printf("next: %ld x %ld x %ld, %zu x %zu x %zu\n", next->ne[0], next->ne[1], next->ne[2], next->nb[0], next->nb[1], next->nb[2]); - local_dst.ne[2] = 1; auto local_next = *next; - local_next.ne[1] = local_next.ne[2] = local_next.ne[3] = 1; + local_next.ne[2] = local_next.ne[1]; + local_next.ne[1] = local_next.ne[3] = 1; + local_next.nb[2] = local_next.nb[1]; - auto local_src1 = *next->src[1]; + local_src1 = *next->src[1]; local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1; - local_src1.nb[1] = dst->ne[0]*sizeof(float); - local_src1.nb[2] = local_src1.nb[3] = 0; + local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size; auto local_src0 = *next->src[0]; local_src0.ne[2] = local_src0.ne[3] = 1; - for (int i = 0; i < n_ids; ++i) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + i*ids->nb[0]); - if (i02 < 0) continue; - local_src0.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - local_next.data = (char *)next->data + i*next->nb[1]; - local_src1.data = (char *)dst_gate_contiguous.get() + i*dst->ne[0]*sizeof(float); - ggml_cuda_op_mul_mat_vec_q(ctx, &local_src0, &local_src1, &local_next, - (const char *)local_src0.data, nullptr, dst_quantized.get() + i*dst_row_size, (float *)local_next.data, - 0, local_src0.ne[1], 1, dst_padded_col_size, stream); - CUDA_CHECK(cudaGetLastError()); - } + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, + (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); return true; } else { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 81b174ef..b9e9c216 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -556,7 +556,7 @@ void ggml_cuda_op_mul_mat_vec_q_id( ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], - src0->nb[2], 0, dst->nb[2], ids->nb[0], + src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data, row_low, row_high, src1_ncols, src1_padded_row_size, stream);