Skip the row id computation for the ffn_down op

Sadly, almost negligible performance gain.
This commit is contained in:
Iwan Kawrakow
2025-08-27 18:07:57 +03:00
parent 46968d4ab1
commit 0ce8068d2b
3 changed files with 34 additions and 31 deletions

View File

@@ -2395,7 +2395,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr, false);
return false;
}
@@ -2702,36 +2702,38 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds,
ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream);
const int64_t ne11_flat = ne12*n_ids;
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), nbytes_src1_q8_1);
size_t ts_src1 = ggml_type_size(src1->type);
quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(),
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
if (dst->src[4]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
{
const int64_t ne11_flat = ne12*n_ids;
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool(), nbytes_src1_q8_1);
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
if (dst->src[5]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
size_t ts_src1 = ggml_type_size(src1->type);
quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(),
src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1,
ne10_padded, ne11_flat, 1, 1, stream);
dst_row.data = dst_up_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false);
if (dst->src[4]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
dst_row.data = dst_gate_contiguous.get();
ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get(), false);
if (dst->src[5]) {
ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
(float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
CUDA_CHECK(cudaGetLastError());
}
}
auto unary_op = (ggml_unary_op)dst->op_params[0];
@@ -2748,8 +2750,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr, true);
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
return true;
}

View File

@@ -314,7 +314,7 @@ void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst,
}
void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) {
const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ids_tensor->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
@@ -377,6 +377,7 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor *
ids_src1 = (int32_t *)ids_data;
ids_dst = ids_src1 + ne_get_rows;
expert_bounds = ids_dst + ne_get_rows;
if (is_next) ids_src1 = ids_dst;
}
else {
GGML_ASSERT(ids_tensor->nb[0] == ggml_element_size(ids_tensor));

View File

@@ -4,7 +4,7 @@
void ggml_cuda_mul_mat_q_id(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids,
ggml_tensor * dst, char * ids_data, char * src1_quantized_data);
ggml_tensor * dst, char * ids_data, char * src1_quantized_data, bool is_next);
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);