mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
iq2_kt: CUDA dequantize
so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
This commit is contained in:
@@ -2847,6 +2847,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_K:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
|
||||
@@ -508,6 +508,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
|
||||
static constexpr int qi = QI4_XS;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
|
||||
static constexpr int qk = QK_K;
|
||||
static constexpr int qr = QR4_XS;
|
||||
static constexpr int qi = QI4_XS;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
|
||||
@@ -333,6 +333,39 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
||||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
|
||||
inline __device__ int nearest_int(float fval) {
|
||||
assert(fval <= 4194303.f);
|
||||
float val = fval + 12582912.f;
|
||||
int i; memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
constexpr uint32_t ka = 89226354;
|
||||
constexpr uint32_t kb = 64248484;
|
||||
constexpr uint32_t kmask = 0x8fff8fff;
|
||||
constexpr uint32_t km32 = 0x3b603b60;
|
||||
const int64_t i = blockIdx.x;
|
||||
const block_iq2_kt * x = (const block_iq2_kt *) vx;
|
||||
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ib = tid; // 0...31
|
||||
dst_t * y = yy + i*QK_K + 8*ib;
|
||||
uint32_t idx = (x[i].ql[ib] | (((x[i].qh[ib%16] >> 4*(ib/16)) & 0xf) << 8)) + 4096;
|
||||
const float dl = (float)x[i].d * iq4k_values[((x[i].scales[ib%16] >> 4*(ib/16)) & 0xf)];
|
||||
uint32_t s;
|
||||
const half * h = (const half *)&s;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
idx = ka*idx + kb;
|
||||
s = (idx & kmask) ^ km32;
|
||||
float val = (float)h[0] + (float)h[1];
|
||||
int ival = nearest_int(16.f*val);
|
||||
y[j] = dl * ival;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
@@ -862,6 +895,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
|
||||
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
|
||||
const int64_t k = nrows * n_per_row;
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
|
||||
const int64_t k = nrows * n_per_row;
|
||||
@@ -1098,6 +1138,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
return dequantize_row_iq2_kt_cuda;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
return dequantize_row_iq2_xs_cuda;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
@@ -1169,6 +1211,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
return dequantize_row_iq2_kt_cuda;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
return dequantize_row_iq2_xs_cuda;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
|
||||
@@ -542,6 +542,11 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq2_kt_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
return 0.f;
|
||||
}
|
||||
|
||||
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ3_K_Q8_1_MMQ 4
|
||||
|
||||
@@ -770,6 +775,13 @@ void mul_mat_vec_iq2_ks_q8_1_cuda(
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
void mul_mat_vec_iq2_kt_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KT, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_kt_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
void mul_mat_vec_iq5_k_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
@@ -32,6 +32,10 @@ void mul_mat_vec_iq2_ks_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
|
||||
|
||||
void mul_mat_vec_iq2_kt_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
|
||||
|
||||
void mul_mat_vec_iq1_bn_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
|
||||
|
||||
@@ -446,6 +446,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
mul_mat_vec_iq2_kt_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
|
||||
@@ -15190,6 +15190,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
case GGML_TYPE_Q6_0: break;
|
||||
case GGML_TYPE_IQ2_K: break;
|
||||
case GGML_TYPE_IQ2_KS: break;
|
||||
case GGML_TYPE_IQ2_KT: break;
|
||||
case GGML_TYPE_IQ3_K: break;
|
||||
case GGML_TYPE_IQ4_K: break;
|
||||
case GGML_TYPE_IQ5_K: break;
|
||||
|
||||
Reference in New Issue
Block a user