diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 93fa4ff4..c234ec07 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -40,6 +40,7 @@ #include "ggml-cuda/add-id.cuh" #include "ggml-cuda/graph.cuh" #include "ggml-cuda/mmq_id.cuh" +#include "ggml-cuda/quantize_id.cuh" #include #include @@ -2393,6 +2394,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + //printf("src0(%s): %ld x %ld x %ld, src1: %ld x %ld x %ld dst: ids: %ld x %ld x %ld, %ld x %ld x %ld\n", + // src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], + // ids->ne[0], ids->ne[1], ids->ne[2], dst->ne[0], dst->ne[1], dst->ne[2]); + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); return false; @@ -2667,21 +2672,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } } - GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + GGML_TENSOR_BINARY_OP_LOCALS + cudaStream_t stream = ctx.stream(); - const int64_t n_as = ne02; - const int64_t n_ids = ids->ne[0]; - - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_1_row = *src0_1; ggml_tensor src0_2_row = *src0_2; ggml_tensor src1_row = *src1; @@ -2689,6 +2686,61 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor ggml_tensor final_dst; ggml_tensor final_src; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; + + if (src1->ne[2] <= 2048 && + ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 && + ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + + const int64_t ne_get_rows = ne12 * n_ids; + ggml_cuda_pool_alloc ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1); + auto ids_src1 = ids_device.get(); + auto ids_dst = ids_src1 + ne_get_rows; + auto expert_bounds = ids_dst + ne_get_rows; + + compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream); + + const int64_t ne11_flat = ne12*n_ids; + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_quantized(ctx.pool(), nbytes_src1_q8_1); + + size_t ts_src1 = ggml_type_size(src1->type); + quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(), + src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, + ne10_padded, ne11_flat, 1, 1, stream); + + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data); + + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr); + ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr); + return true; + } + + return false; + } + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + char * src0_1_original = (char *) src0_1->data; char * src0_2_original = (char *) src0_2->data; char * src1_original = (char *) src1->data; @@ -2728,222 +2780,170 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor final_src.nb[3] = final_src.nb[2]; } - if (false && ne12 == 1) { - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); - if (fuse_down) { - final_dst.src[1] = &dst_row; + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + bool use_quantized_src1 = false; + int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; + if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { + if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); + src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + src1_quantized.alloc(src1_quantized_size); + use_quantized_src1 = true; } - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]); + } + ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); + if (!use_quantized_src1) { + src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); + } + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(i02 >= 0 && i02 < n_as); + src1_row.data = src1_contiguous.get(); - const int64_t i11 = id % ne11; - const int64_t i12 = 0; + bool first = false; //true; - const int64_t i1 = id; - const int64_t i2 = i12; + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - //dst_row.data = dst_original + i1*nb1 + i2*nb2; + bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + if (is_ser) { + if (fuse_down) { + CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); + } else { + CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); + } + } - dst_row.data = dst_up_contiguous.get(); + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = moe_counts[i02]; + + if (num_src1_rows == 0) continue; + size_t mapping_offset = cum_moe_counts[i02]; + + if (use_quantized_src1) { + quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), + src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_quantized.get(); + } + else { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_contiguous.get(); + } + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; + src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; + src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; + + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; + + dst_row.data = dst_up_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); + } + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_gate_contiguous.get(); + if (dst->src[4]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + + dst_row.data = dst_gate_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + } + CUDA_CHECK(cudaGetLastError()); + + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); - - if (fuse_down) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); - CUDA_CHECK(cudaGetLastError()); - - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); - - } else { - - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); - CUDA_CHECK(cudaGetLastError()); - - } } - } else { - //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]); - ggml_cuda_pool_alloc src1_quantized(ctx.pool()); - bool use_quantized_src1 = false; - int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; - if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { - if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { - src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); - src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); - src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); - src1_quantized.alloc(src1_quantized_size); - use_quantized_src1 = true; - } + + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); } - ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); - if (!use_quantized_src1) { - src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); - } - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + CUDA_CHECK(cudaGetLastError()); + if (fuse_down) { - final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); - final_dst.src[1] = &dst_row; + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; + } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); + } + else { - src1_row.data = src1_contiguous.get(); - - bool first = false; //true; - - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); - std::vector moe_counts, cum_moe_counts; - - bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); - if (is_ser) { - if (fuse_down) { - CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); - } else { - CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - } - } - - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = moe_counts[i02]; - - if (num_src1_rows == 0) continue; - size_t mapping_offset = cum_moe_counts[i02]; - - if (use_quantized_src1) { - quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), - src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_quantized.get(); - } - else { - dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_src_to_contiguous<<>>( - src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_contiguous.get(); - } - - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); - - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; - src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; - src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; - - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; - - dst_row.data = dst_up_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - } + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + ne0, + nb1, nb2); CUDA_CHECK(cudaGetLastError()); - - if (dst->src[4]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } - - dst_row.data = dst_gate_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - } - CUDA_CHECK(cudaGetLastError()); - - if (dst->src[5]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } - - auto unary_op = (ggml_unary_op)dst->op_params[0]; - if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { - ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], - 1.702f, 7.0f, stream); - } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); - } - CUDA_CHECK(cudaGetLastError()); - - if (fuse_down) { - - final_dst.ne[1] = num_src1_rows; - final_dst.nb[1] = final_dst.ne[0]*sizeof(float); - final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - if (first) { - printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, - (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], - (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], - (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); - printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", - (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], - (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], - (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); - first = false; - } - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); - - dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - (char *)next->data, final_dst_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - next->ne[0], - next->nb[1], next->nb[2]); - CUDA_CHECK(cudaGetLastError()); - - } - else { - - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_gate_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } } } diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu index 7c0a76fc..78282997 100644 --- a/ggml/src/ggml-cuda/mmq_id.cu +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -4131,6 +4131,46 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, } } +void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, + cudaStream_t stream) { + + const int si1 = nb21 / sizeof(int); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); +} + void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) { GGML_ASSERT( src1->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh index bc5d7c61..c85c468f 100644 --- a/ggml/src/ggml-cuda/mmq_id.cuh +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -5,3 +5,7 @@ void ggml_cuda_mul_mat_q_id( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, char * ids_data, char * src1_quantized_data); + +void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); +