mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Fused ffn_up*unary_op(ffn_gate) for MMVQ (no bias)
We see a nearly 2% token-generation (TG) speedup for Ling-mini-2.0 and about 1% for DeepSeek-Lite.
This commit is contained in:
@@ -2644,6 +2644,27 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
|
||||
local_src1.nb[1] = src_1_ddq_size;
|
||||
}
|
||||
|
||||
bool fuse_next = next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
||||
((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id &&
|
||||
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id;
|
||||
|
||||
if (!dst->src[4] && !dst->src[5]) {
|
||||
local_dst.data = dst_gate_contiguous.get();
|
||||
auto the_destination = fuse_next ? &local_dst : dst;
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, the_destination,
|
||||
(const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(),
|
||||
(float *)dst_gate_contiguous.get(),
|
||||
0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (!fuse_next) return false;
|
||||
}
|
||||
else {
|
||||
|
||||
local_dst.data = dst_up_contiguous.get();
|
||||
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
|
||||
(const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(),
|
||||
@@ -2669,24 +2690,16 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te
|
||||
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
|
||||
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
|
||||
}
|
||||
}
|
||||
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
||||
((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id &&
|
||||
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
|
||||
if (fuse_next) {
|
||||
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
|
||||
GGML_ASSERT(dst->ne[0] % QK8_1 == 0);
|
||||
|
||||
@@ -272,28 +272,28 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 2:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 3:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 4:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 5:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 6:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 7:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 8:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
|
||||
23
ggml/src/ggml-cuda/mmvq-args.h
Normal file
23
ggml/src/ggml-cuda/mmvq-args.h
Normal file
@@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
// Argument bundle handed to the MMVQ launchers (mul_mat_vec_q_cuda_T,
// iqk_mul_mat_vec_q_cuda). It is filled positionally with aggregate
// initialisation in ggml_cuda_op_mul_mat_vec_q_impl, so the member order below
// must stay in sync with that initialiser.
struct mmvq_args {
|
||||
// Quantized weight matrix. In the fused path this is the "up" matrix; in the
// non-fused path it is the only weight matrix passed to the kernel.
const void * vx_u;
|
||||
// Quantized "gate" weight matrix for the fused path; the non-fused call sites
// pass nullptr here.
const void * vx_g;
|
||||
// Right-hand-side operand (src1_ddq_i at the call sites).
const void * vy;
|
||||
// Output buffer.
float * dst;
|
||||
// Base of the ids array read by the *_id kernels (one int per blockIdx.y);
// nullptr at the call sites that do not use ids.
const char * ids_data;
|
||||
// Dimensions forwarded unchanged to the kernels.
const int ncols_x;
|
||||
const int nrows_x;
|
||||
const int nrows_y;
|
||||
// Selects the kernel template instantiation; the visible dispatchers accept
// 1..8 and abort on any other value.
const int ncols_y;
|
||||
const int nrows_dst;
|
||||
// Used as the y dimension of the launch grid in mul_mat_vec_q_cuda_T.
const int ne2;
|
||||
// Byte stride applied to the id read from ids_data when offsetting the weights.
const uint64_t nb02;
|
||||
// Byte stride applied to blockIdx.y when offsetting vy.
const uint64_t nb12;
|
||||
// Byte stride applied to blockIdx.y when offsetting dst.
const uint64_t nb2;
|
||||
// Byte stride between consecutive entries of ids_data.
const uint64_t ids_nb0;
|
||||
// Activation applied in the fused path. GGML_UNARY_OP_COUNT is used by the
// non-fused call sites as the "no fusion" sentinel.
ggml_unary_op unary_op;
|
||||
};
|
||||
|
||||
@@ -262,6 +262,29 @@ static __global__ void mul_mat_vec_q(
|
||||
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// Kernel entry point for the fused up/gate matrix-vector product with ids.
//
// Each blockIdx.y handles one entry of the ids array: it reads the id, offsets
// the up and gate weight pointers by that id, offsets vy and dst by blockIdx.y,
// and then calls the same-named overload of fused_mul_mat_vec_q that takes the
// resolved pointers (that overload is defined outside this view).
//
// Launch configuration used by mul_mat_vec_q_cuda_T: grid = (nblocks, ne2, 1),
// block = (WARP_SIZE, nwarps, 1), no dynamic shared memory.
//
// ids_data must be non-null: it is dereferenced unconditionally.
template <ggml_type type, int ncols_y, int nwarps>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void fused_mul_mat_vec_q(
|
||||
const void * __restrict__ vup, const void * __restrict__ vgate,
|
||||
const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
|
||||
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
|
||||
|
||||
// Index of the ids entry (and of the vy / dst slice) handled by this block.
int i2 = blockIdx.y;
|
||||
// Output slice for this entry; nb2 is a byte stride.
char * cdst = (char *)dst + i2*nb2;
|
||||
// Id stored for this entry; ids_nb0 is the byte stride between entries.
int i02 = *(const int *)(ids_data + i2*ids_nb0);
|
||||
// A negative id means there is nothing to compute for this entry; the
// corresponding dst slice is left unwritten by this kernel.
if (i02 < 0) {
|
||||
return;
|
||||
}
|
||||
// Up and gate weights share the same byte stride (nb02), so both are offset
// by the same amount.
const char * cx_u = (const char *)vup + i02*nb02;
|
||||
const char * cx_g = (const char *)vgate + i02*nb02;
|
||||
const char * cy = (const char *)vy + i2*nb12;
|
||||
// Delegate to the overload that works on already-resolved pointers.
fused_mul_mat_vec_q<type, ncols_y, nwarps>(cx_u, cx_g, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, unary_op);
|
||||
}
|
||||
|
||||
template <ggml_type type, int nwarps>
|
||||
static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
@@ -273,71 +296,95 @@ static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) {
|
||||
int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
|
||||
args.ncols_y < 4 ? 1 : 2 : 1;
|
||||
|
||||
//if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
// switch(args.ncols_y) {
|
||||
// case 1:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 1;
|
||||
// break;
|
||||
// case 2:
|
||||
// case 3:
|
||||
// case 4:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// case 5:
|
||||
// case 6:
|
||||
// case 7:
|
||||
// case 8:
|
||||
// nwarps = 2;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// default:
|
||||
// GGML_ABORT("fatal error");
|
||||
// break;
|
||||
// }
|
||||
//}
|
||||
const int64_t nblocks = (args.nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||
const dim3 block_nums(nblocks, args.ne2, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
fused_mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 2:
|
||||
fused_mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 3:
|
||||
fused_mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 4:
|
||||
fused_mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 5:
|
||||
fused_mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 6:
|
||||
fused_mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 7:
|
||||
fused_mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
case 8:
|
||||
fused_mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 2:
|
||||
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 4:
|
||||
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 5:
|
||||
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx, args.vy, args.dst, args.ids_data,
|
||||
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data,
|
||||
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
@@ -446,9 +493,9 @@ static void mul_mat_vec_iq3_s_q8_1_cuda(const mmvq_args & args, cudaStream_t str
|
||||
static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
|
||||
const int64_t ne00, const int64_t ne0, const int64_t ne2,
|
||||
const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0,
|
||||
const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data,
|
||||
const char * src0_dd_u, const char * src0_dd_g, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data,
|
||||
const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||
const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream) {
|
||||
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
@@ -475,7 +522,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
|
||||
// const uint64_t nb2;
|
||||
// const uint64_t ids_nb0;
|
||||
//};
|
||||
mmvq_args args{/* vx */ src0_dd_i,
|
||||
mmvq_args args{/* vx_u */ src0_dd_u,
|
||||
/* vx_g */ src0_dd_g,
|
||||
/* vy */ src1_ddq_i,
|
||||
/* dst */ dst_dd_i,
|
||||
/* ids_data */ ids_data,
|
||||
@@ -488,7 +536,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
|
||||
/* nb02 */ uint64_t(nb02),
|
||||
/* nb12 */ uint64_t(nb12),
|
||||
/* nb2 */ uint64_t(nb2),
|
||||
/* ids_nb0 */ uint64_t(ids_nb0)
|
||||
/* ids_nb0 */ uint64_t(ids_nb0),
|
||||
/* unary_op */ unary_op
|
||||
};
|
||||
|
||||
switch (type) {
|
||||
@@ -608,9 +657,9 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
|
||||
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
|
||||
ne00, ne0, dst->ne[2],
|
||||
src0->nb[2], src1_row_size, dst->nb[2], 0,
|
||||
src0_dd_i, src1_ddq_i, dst_dd_i, nullptr,
|
||||
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr,
|
||||
row_low, row_high, src1_ncols,
|
||||
src1_padded_row_size, stream);
|
||||
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
|
||||
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
@@ -629,9 +678,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
|
||||
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
|
||||
ne00, ne0, 1, 0, 0, 0, 0,
|
||||
src0_dd_i, src1_ddq_i, dst_dd_i, nullptr,
|
||||
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr,
|
||||
row_low, row_high, src1_ncols,
|
||||
src1_padded_row_size, stream);
|
||||
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
|
||||
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
@@ -655,13 +704,44 @@ void ggml_cuda_op_mul_mat_vec_q_id(
|
||||
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
|
||||
ne00, ne0, dst->ne[2],
|
||||
src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0],
|
||||
src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data,
|
||||
src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, (const char *)ids->data,
|
||||
row_low, row_high, src1_ncols,
|
||||
src1_padded_row_size, stream);
|
||||
src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);
|
||||
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
|
||||
// Host entry point for the fused up/gate MMVQ with ids.
//
// Validates the arguments and forwards to ggml_cuda_op_mul_mat_vec_q_impl with
// both weight pointers and the requested unary op, which selects the fused
// kernel path there.
//
// Parameters:
//   src0       - supplies type, ne[0], ne[3] and nb[2] for the weights; its
//                data pointer is not read here (src0_dd_u / src0_dd_g are used).
//   src1       - supplies ne[0..3] and nb[2] for the right-hand-side operand.
//   ids        - ids tensor; must be non-null with non-null data.
//   dst        - supplies ne[0], ne[2], ne[3] and nb[2] for the output.
//   src0_dd_u  - "up" weight data, must be non-null.
//   src0_dd_g  - "gate" weight data, must be non-null.
//   src1_ddf_i - unused.
//   src1_ddq_i - right-hand-side data passed to the kernel as vy.
//   dst_dd_i   - output buffer.
//   unary_op   - must be SILU, RELU or GELU; anything else aborts.
//
// NOTE(review): the caller visible in this change passes dst->op_params[0]
// straight through whenever both dst->src[4] and dst->src[5] are null. If any
// other unary op (e.g. GGML_UNARY_OP_SWIGLU_OAI) can reach that branch, the
// first assert below fires - confirm the caller excludes it.
// NOTE(review): the iqk_mul_mat_vec_q_cuda launches visible in this change only
// forward args.vx_u; confirm that types dispatched there are never reached
// through this fused entry point.
void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const char * src0_dd_u, const char * src0_dd_g, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream) {
|
||||
|
||||
// Only these activations are accepted by the fused path.
GGML_ASSERT(unary_op == GGML_UNARY_OP_SILU ||
|
||||
unary_op == GGML_UNARY_OP_RELU ||
|
||||
unary_op == GGML_UNARY_OP_GELU);
|
||||
// The fused kernel dereferences both weight pointers and the ids array.
GGML_ASSERT(src0_dd_u && src0_dd_g && ids && ids->data);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
// src1's row length must be a whole number of QK8_1-sized groups.
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1);
|
||||
// Only a single src1 row is supported here.
GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1);
|
||||
// One id per dst slice along dimension 2.
GGML_ASSERT(ids->ne[0] == dst->ne[2]);
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
// Passing a non-null gate pointer together with a real unary op (not
// GGML_UNARY_OP_COUNT) is what selects the fused kernel in the launcher.
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
|
||||
ne00, ne0, dst->ne[2],
|
||||
src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0],
|
||||
src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, (const char *)ids->data,
|
||||
row_low, row_high, src1_ncols,
|
||||
src1_padded_row_size, unary_op, stream);
|
||||
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
|
||||
|
||||
bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
||||
@@ -25,3 +25,9 @@ void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
|
||||
const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, cudaStream_t stream);
|
||||
|
||||
// Fused variant of ggml_cuda_op_mul_mat_vec_q_id: takes separate "up"
// (src0_dd_u) and "gate" (src0_dd_g) weight pointers plus the unary op to apply.
// The definition asserts that unary_op is SILU, RELU or GELU, that both weight
// pointers and ids->data are non-null, and that src1 has a single row.
// src1_ddf_i is unused by the definition.
void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const char * src0_dd_u, const char * src0_dd_g, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream);
|
||||
|
||||
Reference in New Issue
Block a user