Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-02-08 07:20:12 +00:00
Faster MLA on CUDA (#234)
* Slight MLA TG performance improvement on CUDA

  The low MLA performance on CUDA is due to the wk_b * q_nope operation.
  It turns into n_head matrix multiplications, with n_head separate
  quantization and GEMV steps. The associated overhead is just too much
  for TG, where each GEMV is very fast (512 x 128 = 131 KFLOP for
  DeepSeek-Lite, 4X that for DeepSeekV3/R1). The old code also made a
  copy of each q_nope row before quantization; eliminating that copy
  gives a ~2.5% speedup. What needs to happen instead is to launch a
  single computation that quantizes all heads, and then have a kernel
  that does the GEMV for all heads instead of n_head sequential GEMVs.

* Slightly better

* CUDA: Quantize non-contiguous tensors

* Much better MLA

  It is a total hack, but it works.

* Cleanup

  Remove duplicated gemv's.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
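The batching pattern the diff introduces is easy to see in isolation. Below is a minimal, self-contained CUDA sketch of it — not the repository's kernel: the names are made up and a plain float GEMV stands in for the quantized vec-dot machinery. A __device__ worker computes one head's GEMV; a thin __global__ wrapper offsets the base pointers by blockIdx.y, so a single launch covers all n_head heads instead of issuing n_head tiny launches.

#include <cuda_runtime.h>

// One head's GEMV: y = A * x with A of shape nrows x ncols (row-major).
// Assumes blockDim.x == warpSize; blockDim.y rows are computed per block.
__device__ void gemv_one_head(const float * A, const float * x, float * y,
                              int nrows, int ncols) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= nrows) {
        return; // all 32 lanes of a warp share the same row, so this is warp-uniform
    }
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += warpSize) {
        sum += A[(size_t)row * ncols + col] * x[col];
    }
    // reduce the per-lane partial sums within the warp
    for (int offset = warpSize/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (threadIdx.x == 0) {
        y[row] = sum;
    }
}

// Thin wrapper: blockIdx.y selects the head, and per-head byte strides
// (playing the role of nb02/nb12/nb2 in the diff) offset the pointers.
__global__ void gemv_all_heads(const char * A, const char * x, char * y,
                               int nrows, int ncols,
                               size_t nbA, size_t nbx, size_t nby) {
    const int head = blockIdx.y;
    gemv_one_head((const float *)(A + head*nbA),
                  (const float *)(x + head*nbx),
                  (float *)(y + head*nby),
                  nrows, ncols);
}

// Launch: grid.y = n_head replaces a loop of n_head kernel launches, e.g.
//   dim3 block(32, rows_per_block);
//   dim3 grid((nrows + rows_per_block - 1)/rows_per_block, n_head);
//   gemv_all_heads<<<grid, block, 0, stream>>>(A, x, y, nrows, ncols, nbA, nbx, nby);

In the diff itself this shape appears twice — iqk_mul_mat_vec_q and mul_mat_vec_q are each split into a __device__ body plus a __global__ wrapper that takes the per-head byte strides nb02/nb12/nb2, with ne2 added to block_nums — while quantize_tensor_q8_1_cuda applies the same idea to the per-head quantization step.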
@@ -1518,6 +1518,8 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
+    bool quantization_done = false;
+
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
         if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
             continue;
@@ -1561,9 +1563,15 @@ static void ggml_cuda_op_mul_mat(
             }
             dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
-            if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+            if (src1_on_device && (src1_is_contiguous || (src1->ne[1] == 1 && src1->ne[3] == 1 && src1->nb[0] == sizeof(float)))) {
+                if (src1_is_contiguous) {
+                    quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                } else {
+                    //printf("Calling quantize_tensor_q8_1_cuda for %s\n", src0->name);
+                    quantize_tensor_q8_1_cuda(src1, dev[id].src1_ddq, src0->type, stream);
+                }
                 CUDA_CHECK(cudaGetLastError());
+                quantization_done = true;
             }
         }
 
@@ -1583,6 +1591,20 @@ static void ggml_cuda_op_mul_mat(
     }
 
     const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    if (!(split && used_devices > 1) && quantization_done && ne11 == 1 && ne12 > 1 && ne13 == 1) {
+        //printf("invoking fast path for %s x %s\n", src0->name, src1->name);
+        int id = ctx.device;
+        char * src0_dd_i = dev[id].src0_dd;
+        float * src1_ddf_i = dev[id].src1_ddf;
+        char * src1_ddq_i = dev[id].src1_ddq;
+        float * dst_dd_i = dev[id].dst_dd;
+        cudaStream_t stream = ctx.stream(id, 0);
+        ggml_cuda_op_mul_mat_vec_q_3D(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                dev[id].row_low, dev[id].row_high, ne11, src1_padded_col_size, stream);
+        CUDA_CHECK(cudaGetLastError());
+        return;
+    }
+
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -1649,13 +1671,17 @@ static void ggml_cuda_op_mul_mat(
                     }
                 }
             } else if (src1_on_device && !src1_is_contiguous) {
-                CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
-                        src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                if (!quantization_done) {
+                    //printf("Copying %s\n", src1->name);
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                            src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                }
             } else {
                 GGML_ABORT("fatal error");
             }
 
-            if (quantize_src1 && !src1_is_contiguous) {
+            if (quantize_src1 && !src1_is_contiguous && !quantization_done) {
+                //printf("Quantizing %s\n", src1->name);
                 quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
@@ -15,11 +15,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 
 namespace {
 template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__global__ void iqk_mul_mat_vec_q(
+__device__ void iqk_mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
 
@@ -94,10 +90,27 @@ __global__ void iqk_mul_mat_vec_q(
     }
 }
 
+template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__global__ void iqk_mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
+    const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) {
+    int i2 = blockIdx.y;
+    const char * cx = (const char *)vx + i2*nb02;
+    const char * cy = (const char *)vy + i2*nb12;
+    char * cdst = (char *)dst + i2*nb2;
+    iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+}
+
 template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
 void iqk_mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
@@ -132,35 +145,35 @@ void iqk_mul_mat_vec_q_cuda(
         }
     }
     const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_nums(nblocks, ne2, 1);
     const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
     const int64_t row_size = ggml_row_size(type, ncols_x);
 
     switch (ncols_y) {
         case 1:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
             break;
         case 2:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
         case 3:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        case 4:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        case 5:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        case 6:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        case 7:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        case 8:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
            break;
        default:
            GGML_ASSERT(false);
@@ -730,68 +743,78 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
 
 void mul_mat_vec_iq2_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq3_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq4_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq4_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq4_kss_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq2_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq5_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq6_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq1_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 void mul_mat_vec_iq2_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
@@ -2,40 +2,50 @@
 
 void mul_mat_vec_iq2_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq3_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq4_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq5_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq6_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq4_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq4_kss_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq2_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq1_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
 
 void mul_mat_vec_iq2_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
@@ -1,6 +1,6 @@
 #include "mmvq.cuh"
-#include "vecdotq.cuh"
+#include "iqk_mmvq.cuh"
+#include "vecdotq.cuh"
 
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
 
@@ -51,11 +51,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 }
 
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void mul_mat_vec_q(
+static __device__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -132,10 +128,27 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+template <ggml_type type, int ncols_y>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
+    const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) {
+    int i2 = blockIdx.y;
+    const char * cx = (const char *)vx + i2*nb02;
+    const char * cy = (const char *)vy + i2*nb12;
+    char * cdst = (char *)dst + i2*nb2;
+    mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
@@ -145,6 +158,11 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
+    //if (ne2 > 1) {
+    //    printf("%s: ncols_x = %d, nrows_x = %d, nrows_y = %d, ncols_y = %d nrows_dst = %d, ne2 = %d nb02 = %zu, nb12 = %zu, nb2 = %zu\n", __func__,
+    //            ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2);
+    //}
+
     if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
         switch(ncols_y) {
             case 1:
@@ -170,33 +188,33 @@ static void mul_mat_vec_q_cuda(
         }
     }
     const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_nums(nblocks, ne2, 1);
     const dim3 block_dims(WARP_SIZE, nwarps, 1);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
             break;
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
            break;
        default:
            GGML_ABORT("fatal error");
@@ -206,142 +224,311 @@ static void mul_mat_vec_q_cuda(
 
 static void mul_mat_vec_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q6_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq2_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq2_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq2_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq3_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq1_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq1_m_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq4_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
 
 static void mul_mat_vec_iq3_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
 }
+
+namespace {
+void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
+        const int64_t ne00, const int64_t ne10, const int64_t ne0, const int64_t ne2,
+        const int64_t nb02, const int64_t nb12, const int64_t nb2,
+        const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i,
+        const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+        const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t row_diff = row_high - row_low;
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+    const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size);
+
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q6_0:
+            mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ1_BN:
+            mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_BN:
+            mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_K:
+            mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ3_K:
+            mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ4_K:
+            mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ4_KS:
+            mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ4_KSS:
+            mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ2_KS:
+            mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ5_K:
+            mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ6_K:
+            mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+}
+}
+
+void ggml_cuda_op_mul_mat_vec_q_3D(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+    GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1);
+    GGML_ASSERT(src0->ne[2] == src1->ne[2] && src0->ne[2] == dst->ne[2]);
+
+    const int64_t ne0 = dst->ne[0];
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+    const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size);
+
+    ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
+        ne00, ne10, ne0, dst->ne[2],
+        src0->nb[2], src1_row_size, dst->nb[2],
+        src0_dd_i, src1_ddq_i, dst_dd_i,
+        row_low, row_high, src1_ncols,
+        src1_padded_row_size, stream);
+
+    GGML_UNUSED(src1_ddf_i);
+}
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -364,105 +551,11 @@ void ggml_cuda_op_mul_mat_vec_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q6_0:
-            mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_S:
-            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_M:
-            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_BN:
-            mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_BN:
-            mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_NL:
-            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_XS:
-            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_K:
-            mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ3_K:
-            mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_K:
-            mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_KS:
-            mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_KSS:
-            mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_KS:
-            mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ5_K:
-            mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ6_K:
-            mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ3_S:
-            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
+    ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
+        ne00, ne10, ne0, 1, 0, 0, 0,
+        src0_dd_i, src1_ddq_i, dst_dd_i,
+        row_low, row_high, src1_ncols,
+        src1_padded_row_size, stream);
 
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddf_i);
-    GGML_UNUSED(src1_ncols);
-    GGML_UNUSED(src1_padded_row_size);
 }
@@ -2,8 +2,12 @@
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
-void ggml_cuda_op_mul_mat_vec_q(
-    ggml_backend_cuda_context & ctx,
+void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
@@ -37,6 +37,42 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded, const uint64_t stride) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const int64_t ix1 = blockIdx.y;
+
+    const int64_t i_padded = ix1*kx0_padded + ix0;
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int64_t ib = i_padded / QK8_1; // block index
+    const int64_t iqs = i_padded % QK8_1; // quant index
+
+    const float xi = ix0 < kx ? x[ix1*stride + ix0] : 0.0f;
+    float amax = fabsf(xi);
+    float sum = xi;
+
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
+
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
@@ -164,3 +200,19 @@ void quantize_mmq_q8_1_cuda(
             break;
     }
 }
+
+void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) {
+    GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    const int64_t src_padded_col_size = GGML_PAD(src->ne[0], MATRIX_ROW_PADDING);
+    GGML_ASSERT(src_padded_col_size % QK8_1 == 0);
+    if (src->ne[2] == 1 || ggml_is_contiguous(src)) {
+        quantize_row_q8_1_cuda((const float *)src->data, vy, src->ne[0], 1, 1, src_padded_col_size, type, stream);
+        return;
+    }
+    const int64_t block_num_x = (src_padded_col_size + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, src->ne[2]*src->ne[3], 1);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    const uint64_t stride = src->nb[2]/sizeof(float);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>((const float *)src->data, vy, src->ne[0], src_padded_col_size, stride);
+}
@@ -22,3 +22,6 @@ void quantize_row_q8_1_cuda(
 void quantize_mmq_q8_1_cuda(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
     const ggml_type type_x, cudaStream_t stream);
+
+// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1
+void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream);
@@ -3212,7 +3212,7 @@ static bool llama_kv_cache_init(
                 ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
                 //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
 #else
-                ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
+                ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size);
 #endif
                 ggml_format_name(kv, "cache_kv_l%d", i);
                 cache.kv_l.push_back(kv);
@@ -13579,6 +13579,7 @@ struct llm_build_context {
                 cb(wk_b, "wk_b", il);
 
                 q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                //if (q_nope->ne[1] <= 32) q_nope = ggml_cont(ctx0, q_nope);
                 cb(q_nope, "q_nope_perm", il);
 
                 struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
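For orientation, the graph-level effect of that last hunk can be sketched as follows. The shapes are assumptions for illustration — n_head = 16 and the [128, 512] per-head wk_b slice are inferred from the commit message's "512 x 128" per-head GEMV for DeepSeek-Lite, not stated in the diff:

// Sketch, not repository code beyond the final ggml_mul_mat line.
// wk_b:   [128, 512, 16]   -- per-head decompression matrix (ne0 is the shared K dim)
// q_nope: [128, n_tok, 16] -- after ggml_permute(ctx0, q_nope, 0, 2, 1, 3)
// ggml_mul_mat batches over dim 2, i.e. one 512 x 128 GEMV per head in a single op:
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);   // -> [512, n_tok, 16]

Because q_nope is only permuted (the ggml_cont stays commented out), src1 reaches the CUDA backend non-contiguous with ne[1] == 1 during token generation — exactly the case the new quantize_tensor_q8_1_cuda path and the ne11 == 1 && ne12 > 1 fast path in ggml_cuda_op_mul_mat are there to catch.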