CUDA: prompt processing optimizations for MoE models (#739)

* Skip the row id computation for the ffn_down op

Sadly, almost negligible performance gain.

* Also this doesn't do much

* Also this barely moves the needle

* This is slightly better

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-08-30 12:09:41 +03:00
committed by GitHub
parent 46968d4ab1
commit f22a9ef95a
5 changed files with 117 additions and 57 deletions

View File

@@ -2395,7 +2395,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr, false);
return false;
}
@@ -2702,36 +2702,38 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds,
ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream);
const int64_t ne11_flat = ne12*n_ids;
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), nbytes_src1_q8_1);
size_t ts_src1 = ggml_type_size(src1->type);
quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(),
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
if (dst->src[4]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
{
const int64_t ne11_flat = ne12*n_ids;
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), nbytes_src1_q8_1);
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
if (dst->src[5]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
size_t ts_src1 = ggml_type_size(src1->type);
quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(),
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false);
if (dst->src[4]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false);
if (dst->src[5]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
}
auto unary_op = (ggml_unary_op)dst->op_params[0];
@@ -2748,19 +2750,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr, true);
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
return true;
}
return false;
}
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
@@ -2834,20 +2831,19 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
bool first = false; //true;
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
std::vector<int> moe_counts, cum_moe_counts;
const int64_t ne_get_rows = ne12 * n_ids;
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), ne_get_rows + (n_as + 2)/2);
bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
if (is_ser) {
if (fuse_down) {
CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
}
}
compute_row_ids2((const int32_t *)ids->data, dev_row_mapping.get(), (int32_t *)(dev_row_mapping.get() + ne_get_rows),
ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream);
std::vector<int> cum_moe_counts(n_as + 1);
CUDA_CHECK(cudaMemcpyAsync(cum_moe_counts.data(), dev_row_mapping.get() + ne_get_rows, (n_as+1)*sizeof(int),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = moe_counts[i02];
int64_t num_src1_rows = cum_moe_counts[i02+1] - cum_moe_counts[i02];
if (num_src1_rows == 0) continue;
size_t mapping_offset = cum_moe_counts[i02];

View File

@@ -24,11 +24,13 @@ struct mmq_ids_helper_store {
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
struct mmid_row_mapping { int32_t i1, i2; };
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
template <int n_expert_used_template, int type = 0>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
@@ -103,8 +105,12 @@ static __global__ void mmq_ids_helper(
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
if constexpr (type == 0) {
ids_src1[nex_prev + itc] = it;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
} else {
((mmid_row_mapping *)ids_src1)[nex_prev + itc] = {iex_used, it};
}
}
if (threadIdx.x != 0) {
@@ -120,7 +126,7 @@ static __global__ void mmq_ids_helper(
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
template <int n_expert_used_template, int type = 0>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
@@ -135,7 +141,7 @@ static void launch_mmq_ids_helper(
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
mmq_ids_helper<n_expert_used_template, type><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
@@ -313,8 +319,48 @@ void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst,
CUDA_CHECK(cudaGetLastError());
}
void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21,
cudaStream_t stream) {
const int si1 = nb21 / sizeof(int);
const int sis1 = nb12 / nb11;
switch (n_expert_used) {
case 2:
launch_mmq_ids_helper< 2, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0, 1> (ids, (int32_t *)rmapping, nullptr, expert_bounds,
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError());
}
void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) {
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ids_tensor->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
@@ -377,6 +423,7 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor *
ids_src1 = (int32_t *)ids_data;
ids_dst = ids_src1 + ne_get_rows;
expert_bounds = ids_dst + ne_get_rows;
if (is_next) ids_src1 = ids_dst;
}
else {
GGML_ASSERT(ids_tensor->nb[0] == ggml_element_size(ids_tensor));
@@ -460,7 +507,7 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor *
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k, ne12};
use_stream_k, ne12, (int)n_expert_used, (int)ne02};
//printf("ne00 = %ld, ne01 = %ld, ne_get_rows = %ld, s01 = %ld, s1 = %ld\n", ne00, ne01, ne_get_rows, s01, s1);
//printf("ne02 = %ld, s02 = %ld, s12 = %ld, s2 = %ld\n", ne02, s02, s12, s2);

View File

@@ -4,9 +4,14 @@
void ggml_cuda_mul_mat_q_id(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids,
ggml_tensor * dst, char * ids_data, char * src1_quantized_data);
ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next);
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
struct mmid_row_mapping;
void compute_row_ids2(const int32_t * ids, mmid_row_mapping * rmapping, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11);

View File

@@ -3943,7 +3943,7 @@ struct mmq_args_id {
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
bool use_stream_k; int64_t ncols_max;
bool use_stream_k; int64_t ncols_max; int n_experts_used; int n_total_experts;
};
template<ggml_type type>
@@ -4057,9 +4057,21 @@ void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args
const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
const int mmq_x_max = get_mmq_x_max_host(cc);
int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
int ncols_max = args.ncols_max;
if (args.ids_dst && 4*args.n_experts_used < args.n_total_experts) {
ncols_max *= 4*args.n_experts_used;
ncols_max /= args.n_total_experts;
if (ncols_max < 1) ncols_max = 1;
ncols_max = 32*((ncols_max + 31)/32);
//ncols_max = 16*((ncols_max + 15)/16);
if (ncols_max > args.ncols_max) ncols_max = args.ncols_max;
//printf("%s: ncols_max = %d, %d\n", __func__, (int)args.ncols_max, ncols_max);
//mmq_x_max /= 2;
}
int mmq_x_best = 0;
int ntiles_x_best = INT_MAX;
@@ -4070,7 +4082,7 @@ void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args
continue;
}
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntiles_x = (ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;

View File

@@ -13342,7 +13342,7 @@ struct llm_build_context {
// whether to use n_tokens as the matrix dimension during multiplication or n_head
// n_tokens is higher during prompt processing, this allows to optimize for this case
bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head;
bool pp_opt = n_tokens >= 32; //128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head;
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;