This works for mainline supported quants

This commit is contained in:
Iwan Kawrakow
2025-08-25 07:46:20 +03:00
parent c4f2af9ccc
commit e9afb0b8fc
3 changed files with 85 additions and 17 deletions

View File

@@ -2394,13 +2394,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
//printf("src0(%s): %ld x %ld x %ld, src1: %ld x %ld x %ld dst: ids: %ld x %ld x %ld, %ld x %ld x %ld\n",
// src0->name, src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2],
// ids->ne[0], ids->ne[1], ids->ne[2], dst->ne[0], dst->ne[1], dst->ne[2]);
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
return false;
if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
return false;
}
GGML_TENSOR_BINARY_OP_LOCALS
@@ -2679,19 +2676,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
cudaStream_t stream = ctx.stream();
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
ggml_tensor final_dst;
ggml_tensor final_src;
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
if (src1->ne[2] <= 2048 &&
ggml_tensor dst_row = *dst;
if (src1->ne[2] <= 2048 && // TODO: this depends on number of total vs number of active experts -> need to find optimum threshod
ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 &&
ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
const int64_t ne_get_rows = ne12 * n_ids;
ggml_cuda_pool_alloc<int32_t> ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1);
@@ -2746,7 +2738,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
CUDA_CHECK(cudaGetLastError());
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
@@ -2762,6 +2753,12 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
ggml_tensor final_dst;
ggml_tensor final_src;
char * src0_1_original = (char *) src0_1->data;
char * src0_2_original = (char *) src0_2->data;
char * src1_original = (char *) src1->data;

View File

@@ -399,3 +399,73 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor *
ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream);
}
bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) {
bool mmq_supported;
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
mmq_supported = true;
break;
default:
mmq_supported = false;
break;
}
if (!mmq_supported) {
return false;
}
if (turing_mma_available(cc)) {
return true;
}
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
return false;
}
#ifdef GGML_CUDA_FORCE_MMQ
return true;
#endif //GGML_CUDA_FORCE_MMQ
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
return true;
}
return false;
}
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@@ -9,3 +9,4 @@ void ggml_cuda_mul_mat_q_id(
void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream);
bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11);