iq4_kss: CUDA dequantize works

So we can run perplexity. Sadly, the result does not look good
on the bpw vs quantization error plot.
This commit is contained in:
Iwan Kawrakow
2024-10-15 11:48:48 +03:00
parent fd89bf186e
commit b159b2b113
11 changed files with 170 additions and 84 deletions

View File

@@ -44,6 +44,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
{ "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },

View File

@@ -2829,6 +2829,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ3_K:

View File

@@ -543,6 +543,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> {
static constexpr int qi = QI4_XS; static constexpr int qi = QI4_XS;
}; };
// CUDA type traits for IQ4_KSS (4.0 bpw non-linear quantization).
// Reuses the IQ4_XS constants: QK_K values per super-block, and the
// IQ4_XS quantization ratio / ints-per-block for the MMVQ machinery.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> {
static constexpr int qk = QK_K;    // values per super-block
static constexpr int qr = QR4_XS;  // quantization ratio (same as IQ4_XS)
static constexpr int qi = QI4_XS;  // ints per block (same as IQ4_XS)
};
template<> template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> { struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
static constexpr int qk = QK_K; static constexpr int qk = QK_K;

View File

@@ -638,6 +638,36 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst
} }
} }
// Dequantize one 32-value sub-block of an IQ4_KSS row.
//
// Launch layout: one block per QK_K super-block (gridDim.x = ceil(k/QK_K)),
// 32 threads per block; each thread decodes 8 of the 256 values.
// Row layout: a single float row scale, followed by block_iq4_kss structs
// whose qs[] packs, per 32-value sub-block, 4 uint32 words. In each word
// bits 0..29 hold 15+15 parity-encoded nibble bits (decoded below via the
// v ^ (v << 1) transform) and bits 30..31 hold 2 bits of the 8-bit
// sub-block scale, reassembled from all 4 words into `ls`.
template<typename dst_t>
static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
int64_t ii = blockIdx.x;
// Which row this super-block belongs to (rows are row_size bytes apart).
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
// Per-row scale stored as a float at the start of the row.
float scale = *(const float *)cx;
const block_iq4_kss * x = (const block_iq4_kss *)(cx + sizeof(float));
// Super-block index within this row.
const int64_t i = ii - (row*n_per_row)/QK_K;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3 : which uint32 word of the sub-block
const int64_t ib = tid%8; // 0...7 : which 32-value sub-block
dst_t * y = yy + ii*QK_K + 32*ib + 4*il;
const uint32_t * q4 = x[i].qs + 4*ib;
// Reassemble the 8-bit sub-block scale from the top 2 bits of each word.
uint8_t ls = (q4[0] >> 30) | ((q4[1] >> 28) & 0x0c) | ((q4[2] >> 26) & 0x30) | ((q4[3] >> 24) & 0xc0);
// Bit 0 of ls selects one of the two 16-entry value tables; the remaining
// 7 bits encode the scale with a -127 offset (matches the CPU decode).
const float d = scale * ((ls & 254) - 127);
const int8_t * values = iq4k_values + ((ls & 1) << 4);
uint32_t aux32[2];
// Undo the parity encoding: drop the stolen top bits, then v ^= v << 1
// reconstructs the original 8 nibbles of this word.
aux32[0] = (q4[il] & 0x00007fff) | ((q4[il] << 1) & 0x7fff0000);
aux32[0] ^= (aux32[0] << 1);
// Split into low nibbles (first 16 outputs) and high nibbles (next 16).
aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f);
aux32[0] &= 0x0f0f0f0f;
const uint8_t * aux8 = (const uint8_t *)aux32;
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * values[aux8[j+0]];
y[j+16] = d * values[aux8[j+4]];
}
}
template<typename dst_t> template<typename dst_t>
static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x; const int64_t i = blockIdx.x;
@@ -980,6 +1010,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size); dequantize_block_iq4_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
} }
// Host-side launcher: dequantize `nrows` IQ4_KSS rows of `n_per_row` values
// each into y, one thread block of 32 threads per QK_K super-block, enqueued
// on `stream`. Row stride in bytes comes from ggml_row_size().
template<typename dst_t>
static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
    const int64_t bytes_per_row = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
    const int64_t total_values  = nrows * n_per_row;
    // Ceil-divide: one block per super-block, covering a possible tail.
    const int n_blocks = (total_values + QK_K - 1) / QK_K;
    dequantize_block_iq4_kss<<<n_blocks, 32, 0, stream>>>(vx, y, n_per_row, bytes_per_row);
}
template<typename dst_t> template<typename dst_t>
static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row; const int64_t k = nrows * n_per_row;
@@ -1152,6 +1190,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda; return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda; return dequantize_row_iq4_ks_cuda;
case GGML_TYPE_IQ4_KSS:
return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda; return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K:
@@ -1225,6 +1265,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda; return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS:
return dequantize_row_iq4_ks_cuda; return dequantize_row_iq4_ks_cuda;
case GGML_TYPE_IQ4_KSS:
return dequantize_row_iq4_kss_cuda;
case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KS:
return dequantize_row_iq2_ks_cuda; return dequantize_row_iq2_ks_cuda;
case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K:

View File

@@ -239,6 +239,34 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
return dl * __low2float(bq8_1[ib32].ds) * sumi; return dl * __low2float(bq8_1[ib32].ds) * sumi;
} }
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
#define VDR_IQ4_KSS_Q8_1_MMQ  4

// IQ4_KSS x Q8_1 dot product for the MMVQ path.
//
// WARNING: this is an UNIMPLEMENTED STUB. It always returns 0.f, so any
// mat-vec product dispatched through mul_mat_vec_iq4_kss_q8_1_cuda produces
// an all-zero result. The commented-out sketch below mirrors the IQ4_KS
// kernel but does not yet undo the IQ4_KSS parity encoding of the nibbles
// (see dequantize_block_iq4_kss) before the table lookup.
__device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
    float scale = *(const float *)vbq;
    const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
    const uint8_t * all_values = (const uint8_t *)iq4k_values;
    // Silence unused-variable warnings until the kernel is implemented.
    (void)scale; (void)bq4; (void)all_values; (void)iqs;
    // TODO: implement IQ4_KSS dot product (decode parity-packed nibbles first).
    return 0.f;
    //// iqs is 0...28
    //const int ib32 = iqs/4; // Why iqs/4 ?
    //const int32_t * q8 = (const int *)bq8_1[ib32].qs;
    //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
    //const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
    //int v1, v2;
    //int sumi = 0;
    //for (int j = 0; j < 4; ++j) {
    //    get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
    //    sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
    //    sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
    //}
    //return dl * __low2float(bq8_1[ib32].ds) * sumi;
}
#define VDR_IQ5_K_Q8_1_MMVQ 4 #define VDR_IQ5_K_Q8_1_MMVQ 4
#define VDR_IQ5_K_Q8_1_MMQ 4 #define VDR_IQ5_K_Q8_1_MMQ 4
@@ -703,6 +731,13 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
} }
// Entry point for IQ4_KSS x Q8_1 mat-vec multiplication: instantiates the
// generic MMVQ driver with the IQ4_KSS traits and dot-product functor.
// NOTE(review): vec_dot_iq4_kss_q8_1 is currently a stub that returns 0.f,
// so this path yields an all-zero dst until the dot product is implemented.
void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void mul_mat_vec_iq2_ks_q8_1_cuda( void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

View File

@@ -32,6 +32,10 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
void mul_mat_vec_iq4_kss_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
void mul_mat_vec_iq2_ks_q8_1_cuda( void mul_mat_vec_iq2_ks_q8_1_cuda(
const void * vx, const void * vy, float * dst, const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

View File

@@ -462,6 +462,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS:
mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break; break;
case GGML_TYPE_IQ4_KSS:
mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KS:
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break; break;

View File

@@ -15197,6 +15197,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ2_TN: break; case GGML_TYPE_IQ2_TN: break;
case GGML_TYPE_IQ1_TN: break; case GGML_TYPE_IQ1_TN: break;
case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KS: break;
case GGML_TYPE_IQ4_KSS: break;
case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_4_8:
{ {

View File

@@ -2841,12 +2841,16 @@ const uint16_t * scramble_table() {
uint16_t val = i; uint16_t val = i;
int non = popcount(val); int non = popcount(val);
if (non%2) val |= (1 << 15); if (non%2) val |= (1 << 15);
bool found = false;
for (int j = 0; j < int(table.size()); ++j) { for (int j = 0; j < int(table.size()); ++j) {
if ((j ^ (j << 1)) == val) { if ((j ^ (j << 1)) == val) {
//printf("%5d %5u %5d\n", i, val, j); table[i] = j; found = true; break;
table[i] = j; break;
} }
} }
if (!found) {
printf("Oops: did not find for %d %u\n", i, val);
exit(1);
}
} }
} }
return table.data(); return table.data();
@@ -2892,6 +2896,8 @@ uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const f
void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy, void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
const float * quant_weights, float * weight, float * all_scales) { const float * quant_weights, float * weight, float * all_scales) {
constexpr int kBlockSize = 32; constexpr int kBlockSize = 32;
float xv[4], w[4];
uint16_t vps[kBlockSize/4];
const float * dptr_ks = (const float *)cx; const float * dptr_ks = (const float *)cx;
const float d_ks = *dptr_ks; const float d_ks = *dptr_ks;
const block_iq4_ks * iq4ks = (const block_iq4_ks *)(dptr_ks + 1); const block_iq4_ks * iq4ks = (const block_iq4_ks *)(dptr_ks + 1);
@@ -2906,6 +2912,7 @@ void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * c
float sigma2 = 0; float sigma2 = 0;
for (int j = 0; j < QK_K; ++j) sigma2 += xbl[j]*xbl[j]; for (int j = 0; j < QK_K; ++j) sigma2 += xbl[j]*xbl[j];
sigma2 *= 2.f/QK_K; sigma2 *= 2.f/QK_K;
const uint16_t * q4 = (const uint16_t *)iq4ks[ibl].qs;
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
const float * xb = xbl + ib*kBlockSize; const float * xb = xbl + ib*kBlockSize;
if (quant_weights) { if (quant_weights) {
@@ -2917,65 +2924,42 @@ void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * c
const int8_t * values = iq4k_values + ((iq4ks[ibl].scales[ib] & 1) << 4); const int8_t * values = iq4k_values + ((iq4ks[ibl].scales[ib] & 1) << 4);
float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127); float dl = d_ks * ((iq4ks[ibl].scales[ib] & 254) - 127);
float sumqx = 0, sumq2 = 0; float sumqx = 0, sumq2 = 0;
for (int k = 0; k < kBlockSize/4; ++k) {
xv[0] = xb[2*k+0]; xv[1] = xb[2*k+kBlockSize/2]; xv[2] = xb[2*k+1]; xv[3] = xb[2*k+1+kBlockSize/2];
w[0] = weight[2*k+0]; w[1] = weight[2*k+kBlockSize/2]; w[2] = weight[2*k+1]; w[3] = weight[2*k+1+kBlockSize/2];
auto vp = prune_iq4ks(q4[k], values, xv, w, dl);
vps[k] = table[vp & 0x7fff];
for (int j = 0; j < 4; ++j) {
float q = values[(vp >> 4*j) & 0xf];
sumqx += w[j]*q*xv[j];
sumq2 += w[j]*q*q;
}
}
for (int k = 0; k < kBlockSize/8; ++k) { for (int k = 0; k < kBlockSize/8; ++k) {
uint16_t vl = 0, vh = 0; y[ibl].qs[(kBlockSize/8)*ib + k] = vps[2*k+0] | (vps[2*k+1] << 15) | (((iq4ks[ibl].scales[ib] >> 2*k) & 3) << 30);
for (int j = 0; j < 4; ++j) { //y[ibl].qs[(kBlockSize/8)*ib + k] = vps[2*k+0] | (vps[2*k+1] << 15);
uint8_t ql = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] & 0xf;
uint8_t qh = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] >> 4;
vl |= ql << 4*j;
vh |= qh << 4*j;
}
auto vlp = prune_iq4ks(vl, values, xb + 4*k, weight + 4*k, dl);
auto vhp = prune_iq4ks(vh, values, xb + 4*k + kBlockSize/2, weight + 4*k + kBlockSize/2, dl);
auto vlp_s = table[vlp & 0x7fff];
auto vhp_s = table[vhp & 0x7fff];
y[ibl].qs[(kBlockSize/8)*ib + k] = vlp_s | (vhp_s << 15) | (((iq4ks[ibl].scales[ib] >> 2*k) & 3) << 30);
if ((vlp_s ^ (vlp_s << 1)) != vlp) {
printf("Oops(l): vl = %u (%d) vlp = %u (%d) vlp_s = %u check = %u\n", vl, popcount(vl), vlp, popcount(vlp), vlp_s, vlp_s ^ (vlp_s << 1));
exit(1);
}
if ((vhp_s ^ (vhp_s << 1)) != vhp) {
printf("Oops(h): vh = %u (%d) vhp = %u (%d) vhp_s = %u check = %u\n", vh, popcount(vh), vhp, popcount(vhp), vhp_s, vhp_s ^ (vhp_s << 1));
exit(1);
}
for (int j = 0; j < 4; ++j) {
float ql = values[(vlp >> 4*j) & 0xf];
float qh = values[(vhp >> 4*j) & 0xf];
sumqx += weight[4*k+j]*ql*xb[4*k+j] + weight[4*k+j+16]*qh*xb[4*k+j+16];
sumq2 += weight[4*k+j]*ql*ql + weight[4*k+j+16]*qh*qh;
}
//for (int j = 0; j < 4; ++j) {
// uint8_t ql = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] & 0xf;
// uint8_t qh = iq4ks[ibl].qs[(kBlockSize/2)*ib + (kBlockSize/8)*k + j] >> 4;
// uint8_t qlp = (vlp >> 4*j) & 0xf;
// uint8_t qhp = (vhp >> 4*j) & 0xf;
// float diffl = dl*values[ql] - xb[4*k+j];
// float diffh = dl*values[qh] - xb[4*k+j+16];
// float difflp = dl*values[qlp] - xb[4*k+j];
// float diffhp = dl*values[qhp] - xb[4*k+j+16];
// printf("%d %d %d: %u %u %u %u %g %g %g %g\n", ib, k, j, ql, qh, qlp, qhp, diffl, diffh, difflp, diffhp);
//}
} }
scales[ib] = sumq2 > 0 ? sumqx/sumq2 : dl; scales[ib] = sumq2 > 0 ? sumqx/sumq2 : dl;
max_abs_scale = std::max(max_abs_scale, scales[ib]); max_abs_scale = std::max(max_abs_scale, scales[ib]);
q4 += kBlockSize/4;
} }
} }
if (!max_abs_scale) return; //if (!max_abs_scale) return;
float d = max_abs_scale/127; //float d = max_abs_scale/127;
*dptr = d; //*dptr = d;
float id = 1/d; //float id = 1/d;
for (int ibl = 0; ibl < nblock; ++ibl) { //for (int ibl = 0; ibl < nblock; ++ibl) {
auto scales = all_scales + ibl*(QK_K/kBlockSize); // auto scales = all_scales + ibl*(QK_K/kBlockSize);
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) { // for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]+127.f)); // int l = nearest_int(0.5f*(id*scales[ib]+127.f));
l = std::max(0, std::min(127, l)) << 1; // l = std::max(0, std::min(127, l)) << 1;
l |= (iq4ks[ibl].scales[ib] & 1); // l |= (iq4ks[ibl].scales[ib] & 1);
for (int k = 0; k < 4; ++k) { // for (int k = 0; k < 4; ++k) {
y[ibl].qs[4*ib+k] &= 0x3fffffff; // //y[ibl].qs[4*ib+k] &= 0x3fffffff;
y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30); // y[ibl].qs[4*ib+k] |= (((l >> 2*k) & 3) << 30);
} // }
} // }
} //}
} }
} }
@@ -3010,36 +2994,36 @@ void dequantize_row_iq4_kss(const block_iq4_kss * x, float * y, int64_t k) {
const float * dptr = (const float *)x; const float * dptr = (const float *)x;
const float d = *dptr; const float d = *dptr;
x = (const block_iq4_kss *)(dptr + 1); x = (const block_iq4_kss *)(dptr + 1);
//uint32_t aux32[4]; uint32_t aux32[4];
//const uint8_t * aux8 = (const uint8_t *)aux32; const uint8_t * aux8 = (const uint8_t *)aux32;
for (int ibl = 0; ibl < k/QK_K; ++ibl) { for (int ibl = 0; ibl < k/QK_K; ++ibl) {
auto qs = x[ibl].qs; auto qs = x[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) { for (int ib = 0; ib < QK_K/32; ++ib) {
uint8_t ls = ((qs[0] >> 30) | ((qs[1] >> 28) & 0x0c) | ((qs[2] >> 26) & 0x30) | ((qs[3] >> 24) & 0xc0)); //uint8_t ls = ((qs[0] >> 30) | ((qs[1] >> 28) & 0x0c) | ((qs[2] >> 26) & 0x30) | ((qs[3] >> 24) & 0xc0));
const int8_t * values = iq4k_values + ((ls & 1) << 4);
const float dl = d * ((ls & 254) - 127);
for (int k = 0; k < 4; ++k) {
uint16_t vl = qs[k] & 0x7fff;
vl ^= (vl << 1);
uint16_t vh = (qs[k] >> 15) & 0x7fff;
vh ^= (vh << 1);
for (int j = 0; j < 4; ++j) {
y[4*k + j + 0] = dl*values[(vl >> 4*j) & 0xf];
y[4*k + j + 16] = dl*values[(vh >> 4*j) & 0xf];
}
}
//int16_t ls = 0;
//for (int k = 0; k < 4; ++k) {
// aux32[k] = (qs[k] & 0x00007fff) | ((qs[k] << 1) & 0x7fff0000);
// aux32[k] ^= (aux32[k] << 1);
// ls |= (qs[k] >> 30) << 2*k;
//}
//const int8_t * values = iq4k_values + ((ls & 1) << 4); //const int8_t * values = iq4k_values + ((ls & 1) << 4);
//float dl = d * ((ls & 254) - 127); //const float dl = d * ((ls & 254) - 127);
//for (int j = 0; j < 16; ++j) { //for (int k = 0; k < 4; ++k) {
// y[j+ 0] = dl * values[aux8[j] & 0xf]; // uint16_t vl = qs[k] & 0x7fff;
// y[j+16] = dl * values[aux8[j] >> 4]; // vl ^= (vl << 1);
// uint16_t vh = (qs[k] >> 15) & 0x7fff;
// vh ^= (vh << 1);
// for (int j = 0; j < 4; ++j) {
// y[4*k + j + 0] = dl*values[(vl >> 4*j) & 0xf];
// y[4*k + j + 16] = dl*values[(vh >> 4*j) & 0xf];
// }
//} //}
int16_t ls = 0;
for (int k = 0; k < 4; ++k) {
aux32[k] = (qs[k] & 0x00007fff) | ((qs[k] << 1) & 0x7fff0000);
aux32[k] ^= (aux32[k] << 1);
ls |= (qs[k] >> 30) << 2*k;
}
const int8_t * values = iq4k_values + ((ls & 1) << 4);
float dl = d * ((ls & 254) - 127);
for (int j = 0; j < 16; ++j) {
y[j+ 0] = dl * values[aux8[j] & 0xf];
y[j+16] = dl * values[aux8[j] >> 4];
}
y += 32; y += 32;
qs += 4; qs += 4;
} }

View File

@@ -180,6 +180,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_KS = 145, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KS = 145, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
}; };

View File

@@ -3795,6 +3795,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break;
case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break;
case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break;
case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break;
case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break;
@@ -4498,6 +4499,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw";
@@ -15651,7 +15653,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS) { ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS) {
new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K;
} }
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && !qs.has_output) { else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) {
new_type = GGML_TYPE_IQ5_K; new_type = GGML_TYPE_IQ5_K;
} }
else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) { else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) {
@@ -15742,7 +15745,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && qs.model.hparams.n_gqa() >= 2) { else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) {
new_type = GGML_TYPE_IQ5_K; new_type = GGML_TYPE_IQ5_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) { else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) {
@@ -15822,7 +15826,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
} }
} }
else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS) && !qs.has_imatrix) { else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS ||
ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_imatrix) {
new_type = GGML_TYPE_Q5_K; new_type = GGML_TYPE_Q5_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
@@ -15910,7 +15915,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K ||
new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN || new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_IQ2_TN ||
new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ1_TN || new_type == GGML_TYPE_IQ4_KS ||
new_type == GGML_TYPE_IQ2_KS) { new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS) {
int nx = tensor->ne[0]; int nx = tensor->ne[0];
int ny = tensor->ne[1]; int ny = tensor->ne[1];
if (nx % QK_K != 0) { if (nx % QK_K != 0) {
@@ -15942,6 +15947,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ4_K:
@@ -16055,6 +16061,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break;
case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break;
case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break;