Now also -fmoe works

This commit is contained in:
Iwan Kawrakow
2025-08-24 18:59:18 +03:00
parent af6b5365cc
commit 406da6670d
3 changed files with 254 additions and 210 deletions

View File

@@ -40,6 +40,7 @@
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/graph.cuh"
#include "ggml-cuda/mmq_id.cuh"
#include "ggml-cuda/quantize_id.cuh"
#include <algorithm>
#include <array>
@@ -2393,6 +2394,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
//printf("src0(%s): %ld x %ld x %ld, src1: %ld x %ld x %ld dst: ids: %ld x %ld x %ld, %ld x %ld x %ld\n",
// src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2],
// ids->ne[0], ids->ne[1], ids->ne[2], dst->ne[0], dst->ne[1], dst->ne[2]);
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
return false;
@@ -2667,21 +2672,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
}
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers");
GGML_TENSOR_BINARY_OP_LOCALS
cudaStream_t stream = ctx.stream();
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
@@ -2689,6 +2686,61 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
ggml_tensor final_dst;
ggml_tensor final_src;
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
if (src1->ne[2] <= 2048 &&
ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 &&
ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
const int64_t ne_get_rows = ne12 * n_ids;
ggml_cuda_pool_alloc<int32_t> ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1);
auto ids_src1 = ids_device.get();
auto ids_dst = ids_src1 + ne_get_rows;
auto expert_bounds = ids_dst + ne_get_rows;
compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds,
ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream);
const int64_t ne11_flat = ne12*n_ids;
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), nbytes_src1_q8_1);
size_t ts_src1 = ggml_type_size(src1->type);
quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(),
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst->data);
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
return true;
}
return false;
}
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
char * src0_1_original = (char *) src0_1->data;
char * src0_2_original = (char *) src0_2->data;
char * src1_original = (char *) src1->data;
@@ -2728,57 +2780,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
final_src.nb[3] = final_src.nb[2];
}
if (false && ne12 == 1) {
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
if (fuse_down) {
final_dst.src[1] = &dst_row;
}
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]);
if (i02 < 0 || i02 >= n_as) continue;
//GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = 0;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_1_row.data = src0_1_original + i02*nb02;
src0_2_row.data = src0_2_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
//dst_row.data = dst_original + i1*nb1 + i2*nb2;
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
CUDA_CHECK(cudaGetLastError());
if (fuse_down) {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
CUDA_CHECK(cudaGetLastError());
final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
CUDA_CHECK(cudaGetLastError());
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
CUDA_CHECK(cudaGetLastError());
}
}
} else {
//printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
bool use_quantized_src1 = false;
int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
@@ -2945,7 +2946,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
CUDA_CHECK(cudaGetLastError());
}
}
}
return fuse_down;
}

View File

@@ -4131,6 +4131,46 @@ static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx,
}
}
// Build the src1 row-id mapping (ids_src1 -> ids_dst) and the per-expert row
// bounds (expert_bounds) consumed by the MoE MMQ path.
// Dispatches to the templated helper with a compile-time expert count for the
// common values; template argument 0 selects the generic fallback.
// nb11/nb12 are byte strides of src1; nb21 is the byte stride of one ids row.
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21,
        cudaStream_t stream) {
    const int si1  = nb21 / sizeof(int); // ids row stride in int elements
    const int sis1 = nb12 / nb11;        // src1 sequence stride measured in rows
    // Specializing on n_expert_used lets the helper unroll its inner loop.
    switch (n_expert_used) {
        case  2: launch_mmq_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        case  4: launch_mmq_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        case  6: launch_mmq_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        case  8: launch_mmq_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        case 16: launch_mmq_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        case 32: launch_mmq_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
        default: launch_mmq_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, ne02, ne12, n_expert_used, ne11, si1, sis1, stream); break;
    }
    // Surface any launch-configuration error from the helper kernels.
    CUDA_CHECK(cudaGetLastError());
}
void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);

View File

@@ -5,3 +5,7 @@
// Indirect (per-expert) quantized matrix multiplication for MoE graphs:
// ids selects, per token, which expert slices of src0 multiply src1.
// ids_data and src1_quantized_data may be nullptr; when non-null they carry
// precomputed row-id mappings and q8_1-quantized activations so fused
// up/gate callers can avoid recomputing them.
// NOTE(review): roles of ids_data/src1_quantized_data inferred from call
// sites in the CUDA backend — confirm against the definition.
void ggml_cuda_mul_mat_q_id(
    ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids,
    ggml_tensor * dst, char * ids_data, char * src1_quantized_data);
// Computes the src1 row index mapping (ids_src1/ids_dst, ne12*n_expert_used
// entries each) and per-expert row bounds (expert_bounds, ne02 + 1 entries)
// for the MMQ MoE path. nb11/nb12 are byte strides of src1; nb21 is the byte
// stride of one row of ids. Launches asynchronously on stream.
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
    int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);