Fuse mul_mat_id and add_id into a single kernel for mmvq

This commit is contained in:
Iwan Kawrakow
2025-10-24 18:29:16 +03:00
parent 4a08ac7241
commit 6b57074431
3 changed files with 68 additions and 123 deletions

View File

@@ -1862,93 +1862,6 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
/*
static void ggml_cuda_op_gemv_id(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
quantize_cuda_t quantize_src1) {
GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_nrows(src1) == 1);
GGML_ASSERT(src0_ids->ne[1] == 1);
GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]);
GGML_ASSERT(dst->ne[1] == 1);
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context;
ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
int device_id = ctx.device;
GGML_ASSERT(src0_ctx->device == device_id);
GGML_ASSERT(src1_ctx->device == device_id);
GGML_ASSERT(dst_ctx->device == device_id);
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
GGML_ASSERT(!split);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne10 = src1->ne[0];
const int64_t nrows1 = 1;
const int64_t ne0 = dst->ne[0];
const int64_t ne2 = dst->ne[2];
const int64_t nb2 = dst->nb[2];
// Why?
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
const size_t src0_rs = ggml_row_size(src0->type, ne00);
const size_t q8_1_ts = sizeof(block_q8_1);
const size_t q8_1_bs = QK8_1;
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src0_dd_alloc;
ggml_cuda_pool_alloc<float> src1_ddf_alloc;
ggml_cuda_pool_alloc<char> src1_ddq_alloc;
ggml_cuda_pool_alloc<float> dst_dd_alloc;
char * src0_dd = nullptr;
float * src1_ddf = (float *)src1->data;
char * src1_ddq = nullptr; // q8_1
float * dst_dd = (float *)dst->data;
bool quantization_done = false;
const bool src1_on_device = device_id == src1_ctx->device;
const bool dst_on_device = device_id == dst_ctx->device;
ggml_cuda_set_device(device_id);
cudaStream_t stream = ctx.stream(device_id, 0);
src0_dd = (char *) src0->data;
if (quantize_src1) {
size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size);
quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream);
}
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst,
(const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data,
0, ne01, 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
}
*/
static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
@@ -2433,7 +2346,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
local_src1.nb[1] = src_1_ddq_size;
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst,
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst, nullptr,
(const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
@@ -2448,7 +2361,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
if (next_src0_ctx->device == device_id &&
next_dst_ctx->device == device_id) {
local_dst.data = next->data;
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst, nullptr,
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
@@ -2587,7 +2500,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
return false;
}
static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
const ggml_cgraph * graph, int i) {
ggml_tensor * next = i + 1 < graph->n_nodes ? graph->nodes[i+1] : nullptr;
const ggml_tensor * src0_1 = dst->src[0];
const ggml_tensor * src0_2 = dst->src[1];
const ggml_tensor * src0 = src0_1;
@@ -2664,7 +2580,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream);
CUDA_CHECK(cudaGetLastError());
if (!fuse_next) return false;
if (!fuse_next) return i;
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
GGML_ASSERT(dst->ne[0] % QK8_1 == 0);
@@ -2689,14 +2605,27 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
auto local_src0 = *next->src[0];
local_src0.ne[2] = local_src0.ne[3] = 1;
CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
int result = i + 1;
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
if (i+2 < graph->n_nodes &&
graph->nodes[i+2]->op == GGML_OP_ADD_ID &&
graph->nodes[i+2]->src[0] == next &&
graph->nodes[i+2]->src[2] == ids) {
//auto bias = graph->nodes[i+2]->src[1];
//printf("Fusing bias: ids: %ld x %ld x %ld x %ld, bias: %ld x %ld x %ld x %ld\n",
//        ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3]);
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, graph->nodes[i+2]->src[1],
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)graph->nodes[i+2]->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
++result;
} else {
ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, nullptr,
(const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
}
CUDA_CHECK(cudaGetLastError());
return true;
return result;
}
}
@@ -2781,10 +2710,10 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
//ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
return true;
return i+1;
}
return false;
return i;
}
std::vector<char> ids_host(ggml_nbytes(ids));
@@ -3003,7 +2932,7 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
}
}
return fuse_down;
return fuse_down ? i+1 : i;
}
static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -3293,7 +3222,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
if (ggml_cuda_mul_mat_id(ctx, dst, next)) ++i;
break;
case GGML_OP_MOE_FUSED_UP_GATE:
if (ggml_cuda_moe_up_gate_unary(ctx, dst, next)) ++i;
i = ggml_cuda_moe_up_gate_unary(ctx, dst, cgraph, i);
break;
case GGML_OP_FUSED_UP_GATE:
ggml_cuda_up_gate_unary(ctx, dst);

View File

@@ -66,7 +66,8 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
template <ggml_type type, int ncols_y, int nwarps>
static __device__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy,
const float * bias, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -138,7 +139,7 @@ static __device__ void mul_mat_vec_q(
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x];
}
}
}
@@ -258,9 +259,10 @@ template <ggml_type type, int ncols_y, int nwarps>
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const char * __restrict__ ids_data, const void * __restrict__ bias,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, const int64_t bias_nb1) {
int i2 = blockIdx.y;
char * cdst = (char *)dst + i2*nb2;
@@ -270,7 +272,8 @@ static __global__ void mul_mat_vec_q(
}
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
const float * b = (const float *)(bias ? ids_data ? (const char *)bias + i02*bias_nb1 : bias : nullptr);
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
template <ggml_type type, int ncols_y, int nwarps>
@@ -363,36 +366,36 @@ static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) {
} else {
switch (args.ncols_y) {
case 1:
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 2:
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 3:
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 4:
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 5:
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 6:
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 7:
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 8:
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
default:
GGML_ABORT("fatal error");
@@ -689,6 +692,7 @@ void ggml_cuda_op_mul_mat_vec_q(
void ggml_cuda_op_mul_mat_vec_q_id(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
@@ -702,10 +706,21 @@ void ggml_cuda_op_mul_mat_vec_q_id(
const int64_t ne0 = dst->ne[0];
if (bias) {
GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(bias->ne[0] == ne0);
if (ids) {
//GGML_ASSERT(bias->ne[1] == src0->ne[2]);
GGML_ASSERT(bias->ne[2] == 1 && bias->ne[3] == 1);
} else {
GGML_ASSERT(ggml_nrows(bias) == 1);
}
}
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
ne00, ne0, dst->ne[2],
src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], 0,
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, (const char *)ids->data, nullptr, nullptr,
src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], bias ? bias->nb[1] : 0,
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, (const char *)ids->data, bias ? bias->data : nullptr, nullptr,
row_low, row_high, src1_ncols,
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);

View File

@@ -22,6 +22,7 @@ void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_tensor * bias,
const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);