mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
iq2_k: Basics
Quantize/dequantize, CUDA deqantize, AVX512 iqk_mul_mat.
This commit is contained in:
@@ -390,6 +390,7 @@ extern "C" {
|
||||
GGML_TYPE_IQ2_BN = 35,
|
||||
GGML_TYPE_Q8_K64 = 36,
|
||||
GGML_TYPE_IQ4_K = 37,
|
||||
GGML_TYPE_IQ2_K = 38,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
@@ -437,6 +438,7 @@ extern "C" {
|
||||
GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ4_K = 30, // except 1d tensors
|
||||
GGML_FTYPE_MOSTLY_IQ2_K = 31, // except 1d tensors
|
||||
};
|
||||
|
||||
// available tensor operations:
|
||||
|
||||
@@ -454,6 +454,14 @@ typedef struct {
|
||||
} block_iq4_k;
|
||||
static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding");
|
||||
|
||||
typedef struct {
|
||||
ggml_half d;
|
||||
uint16_t extra;
|
||||
uint8_t scales[QK_K/32];
|
||||
uint8_t qs[QK_K/4];
|
||||
} block_iq2_k;
|
||||
static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding");
|
||||
|
||||
#endif // GGML_COMMON_DECL
|
||||
#endif // GGML_COMMON_DECL
|
||||
|
||||
@@ -1890,5 +1898,10 @@ GGML_TABLE_BEGIN(int8_t, iq4k_values, 32)
|
||||
-123, -100, -79, -61, -45, -31, -18, -6, 5, 17, 29, 42, 57, 73, 93, 117
|
||||
GGML_TABLE_END()
|
||||
|
||||
GGML_TABLE_BEGIN(int8_t, iq2nl_values, 8)
|
||||
-31, -13, 1, 17, -26, -8, 6, 22
|
||||
GGML_TABLE_END()
|
||||
|
||||
|
||||
#endif // GGML_COMMON_IMPL
|
||||
#endif // GGML_COMMON_IMPL
|
||||
|
||||
@@ -2754,6 +2754,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
return true;
|
||||
|
||||
@@ -669,6 +669,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
|
||||
static constexpr int qi = QI4_XS;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
static constexpr int qr = QR4_XS;
|
||||
static constexpr int qi = QI4_XS;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
|
||||
@@ -543,6 +543,32 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq2_k * x = (const block_iq2_k *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
int ib128 = tid/16; // 0 or 1
|
||||
int il = tid%16; // 0...15
|
||||
dst_t * y = yy + i*QK_K + 128*ib128 + 2*il;
|
||||
const float d = (float)x[i].d * 1.025f; //1.0325f;
|
||||
const float dl1 = d * (2*((x[i].scales[4*ib128+0] >> 4*(il/8)) & 0xf) - 15);
|
||||
const float dl2 = d * (2*((x[i].scales[4*ib128+1] >> 4*(il/8)) & 0xf) - 15);
|
||||
const float dl3 = d * (2*((x[i].scales[4*ib128+2] >> 4*(il/8)) & 0xf) - 15);
|
||||
const float dl4 = d * (2*((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 15);
|
||||
const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
|
||||
const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
|
||||
y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)];
|
||||
y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)];
|
||||
y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
|
||||
@@ -678,6 +704,12 @@ static void dequantize_row_iq4_k_cuda(const void * vx, dst_t * y, const int64_t
|
||||
dequantize_block_iq4_k<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
|
||||
const int nb = (k + QK_K - 1) / QK_K;
|
||||
dequantize_block_iq2_k<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
|
||||
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
@@ -744,6 +776,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
return dequantize_row_iq4_k_cuda;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
return dequantize_row_iq2_k_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
@@ -797,6 +831,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
return dequantize_row_iq4_xs_cuda;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
return dequantize_row_iq4_k_cuda;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
return dequantize_row_iq2_k_cuda;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
|
||||
@@ -25,6 +25,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
||||
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
|
||||
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
|
||||
type == GGML_TYPE_IQ4_K ? vec_dot_iq4_k_q8_1 :
|
||||
type == GGML_TYPE_IQ2_K ? vec_dot_iq2_k_q8_1 :
|
||||
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
|
||||
nullptr;
|
||||
}
|
||||
@@ -48,6 +49,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
|
||||
type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
|
||||
type == GGML_TYPE_IQ4_K ? VDR_IQ4_K_Q8_1_MMVQ :
|
||||
type == GGML_TYPE_IQ2_K ? VDR_IQ2_K_Q8_1_MMVQ :
|
||||
1;
|
||||
}
|
||||
|
||||
@@ -352,6 +354,13 @@ static void mul_mat_vec_iq4_k_q8_1_cuda(
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_k_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||
@@ -443,6 +452,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
case GGML_TYPE_IQ4_K:
|
||||
mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
|
||||
@@ -1274,3 +1274,35 @@ static __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
|
||||
return d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
|
||||
#define VDR_IQ2_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ2_K_Q8_1_MMQ 4
|
||||
|
||||
// TODO
|
||||
static __device__ __forceinline__ float vec_dot_iq2_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
return 0;
|
||||
//
|
||||
// const block_iq2_k * bq4 = (const block_iq2_k *) vbq + kbx;
|
||||
// const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
//
|
||||
// // iqs is 0...28
|
||||
// const int ib32 = iqs/4;
|
||||
// // Why iqs/4 ?
|
||||
// const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
// const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
|
||||
// const uint16_t extra = bq4->extra >> 2*ib32;
|
||||
// int v1, v2;
|
||||
// int sumi1 = 0, sumi2 = 0;
|
||||
// for (int j = 0; j < 4; ++j) {
|
||||
// const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
|
||||
// get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
|
||||
// sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
|
||||
// sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
|
||||
// }
|
||||
// const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
|
||||
// const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
|
||||
// const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
// const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
// return d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
|
||||
|
||||
@@ -14947,6 +14947,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
{
|
||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_K: break;
|
||||
case GGML_TYPE_IQ4_K: break;
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
|
||||
@@ -992,6 +992,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ2_K] = {
|
||||
.type_name = "iq2_k",
|
||||
.blck_size = QK_K,
|
||||
.type_size = sizeof(block_iq2_k),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_k,
|
||||
.from_float = quantize_row_iq2_k,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_k_ref,
|
||||
.vec_dot = vec_dot_iq2_k_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
@@ -3342,6 +3354,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
||||
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
|
||||
@@ -9592,6 +9605,7 @@ static void ggml_compute_forward_add(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -9973,6 +9987,7 @@ static void ggml_compute_forward_add1(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -10104,6 +10119,7 @@ static void ggml_compute_forward_acc(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -13024,6 +13040,7 @@ static void ggml_compute_forward_out_prod(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -13215,6 +13232,7 @@ static void ggml_compute_forward_set(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -13480,6 +13498,7 @@ static void ggml_compute_forward_get_rows(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
@@ -14072,6 +14091,7 @@ static void ggml_compute_forward_clamp(
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q8_K:
|
||||
@@ -20808,6 +20828,7 @@ size_t ggml_quantize_chunk(
|
||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
|
||||
@@ -742,6 +742,88 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
|
||||
|
||||
};
|
||||
|
||||
struct IQXKScales {
|
||||
IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm_set1_epi8(shift)), min(_mm256_set1_epi8(min_val)) {}
|
||||
template <typename Q8>
|
||||
inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m512i * scales) const {
|
||||
auto extra128 = _mm_set1_epi16(extra);
|
||||
extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
|
||||
extra128 = _mm_and_si128(extra128, e5);
|
||||
extra128 = _mm_shuffle_epi8(extra128, eshuffle);
|
||||
auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)),
|
||||
_mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128)));
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
|
||||
accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
|
||||
}
|
||||
scales16 = MM256_SET_M128I(scales8, scales8);
|
||||
scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
|
||||
scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
|
||||
}
|
||||
const __m128i eshift;
|
||||
const __m256i min;
|
||||
const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
|
||||
const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
|
||||
const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
|
||||
const __m128i e5 = _mm_set1_epi8(5);
|
||||
const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
|
||||
const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
|
||||
};
|
||||
|
||||
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
|
||||
DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(IQXKScales(5, -32)), values(load_values()) {}
|
||||
template <typename Q8>
|
||||
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
prepare(x[i].qs);
|
||||
iqxk.process(i, d, x[i].extra, make_scales(x[i].scales), q8, accm, scales);
|
||||
//auto scales8 = make_scales(x[i].scales);
|
||||
//auto extra128 = _mm_set1_epi16(x[i].extra);
|
||||
//extra128 = _mm_cmpeq_epi8(_mm_and_si128(extra128, emask), emask);
|
||||
//extra128 = _mm_and_si128(extra128, e5);
|
||||
//extra128 = _mm_shuffle_epi8(extra128, eshuffle);
|
||||
//auto scales16 = _mm256_mullo_epi16(_mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, scale_shuffle)),
|
||||
// _mm256_add_epi16(_mm256_set1_epi16(-32), _mm256_cvtepi8_epi16(extra128)));
|
||||
//for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
// const __m256i prod = _mm256_madd_epi16(scales16, q8.load_bsums(iy, i));
|
||||
// accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
|
||||
//}
|
||||
//scales16 = MM256_SET_M128I(scales8, scales8);
|
||||
//scales[0] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle1));
|
||||
//scales[1] = _mm512_cvtepi8_epi16(_mm256_shuffle_epi8(scales16, shuffle2));
|
||||
}
|
||||
inline void prepare(const uint8_t * q2) {
|
||||
bits.prepare(q2);
|
||||
bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]);
|
||||
bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]);
|
||||
bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]);
|
||||
bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]);
|
||||
}
|
||||
static inline __m512i load_values() {
|
||||
static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0};
|
||||
auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl);
|
||||
auto val256 = MM256_SET_M128I(val128, val128);
|
||||
return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
|
||||
}
|
||||
inline __m128i make_scales(const uint8_t * scales_l) const {
|
||||
uint64_t aux64; std::memcpy(&aux64, scales_l, 8);
|
||||
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), _mm_set1_epi8(0xf));
|
||||
return _mm_add_epi8(_mm_slli_epi16(scl, 1), m15);
|
||||
}
|
||||
Q2Bits bits;
|
||||
IQXKScales iqxk;
|
||||
|
||||
const __m512i values;
|
||||
const __m128i m15 = _mm_set1_epi8(-15);
|
||||
//const __m128i scale_shuffle = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
|
||||
//const __m128i m15 = _mm_set1_epi8(-15);
|
||||
//const __m128i emask = _mm_set_epi32(0x80804040, 0x20201010, 0x08080404, 0x02020101);
|
||||
//const __m128i eshuffle = _mm_set_epi32(0x0f0d0b09, 0x07050301, 0x0e0c0a08, 0x06040200);
|
||||
//const __m128i e5 = _mm_set1_epi8(5);
|
||||
//const __m256i shuffle1 = _mm256_set_epi64x(0x0b0b0b0b09090909, 0x0303030301010101, 0x0a0a0a0a08080808, 0x0202020200000000);
|
||||
//const __m256i shuffle2 = _mm256_set_epi64x(0x0f0f0f0f0d0d0d0d, 0x0707070705050505, 0x0e0e0e0e0c0c0c0c, 0x0606060604040404);
|
||||
};
|
||||
|
||||
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
|
||||
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
|
||||
template <typename Q8>
|
||||
@@ -784,11 +866,6 @@ struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
|
||||
auto sch = _mm_shuffle_epi8(aux, hshuff);
|
||||
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
|
||||
}
|
||||
//static __m256i load_shuffle(int i) {
|
||||
// static const uint64_t k_shuffles[8] = {0x0202020200000000, 0x0a0a0a0a08080808, 0x0303030301010101, 0x0b0b0b0b09090909,
|
||||
// 0x0606060604040404, 0x0e0e0e0e0c0c0c0c, 0x0707070705050505, 0x0f0f0f0f0d0d0d0d};
|
||||
// return _mm256_loadu_si256((const __m256i *)k_shuffles + i);
|
||||
//}
|
||||
|
||||
Q4Bits bits;
|
||||
const __m512i values;
|
||||
@@ -2897,6 +2974,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4XS>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ2K>(mm);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
assert (ne00 % QK_K == 0);
|
||||
MulMat::set_functions<DequantizerIQ4K>(mm);
|
||||
|
||||
@@ -471,8 +471,8 @@ void vec_dot_iq4_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx,
|
||||
const int8_t * q8 = y[ibl].qs;
|
||||
int32_t sum = 0;
|
||||
for (int ib = 0; ib < QK_K/32; ++ib) {
|
||||
const int ls1 = (x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30) - 32;
|
||||
const int ls2 = (x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30) - 32;
|
||||
const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
|
||||
const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
|
||||
h >>= 4;
|
||||
const int8_t * values1 = iq4k_values + 16*(extra & 1);
|
||||
const int8_t * values2 = iq4k_values + 8*(extra & 2);
|
||||
@@ -698,3 +698,218 @@ size_t quantize_iq4_k(const float * src, void * dst, int64_t nrows, int64_t n_pe
|
||||
}
|
||||
return nrows * nblock * sizeof(block_iq4_k);
|
||||
}
|
||||
|
||||
//
|
||||
// ============================================== iq2_K
|
||||
//
|
||||
|
||||
namespace {
|
||||
|
||||
inline int best_index_iq2nl(const int8_t * values, float x) {
|
||||
int idx = x < values[1] ? 0 : x > values[2] ? 2 : 1;
|
||||
return x - values[idx] < values[idx+1] - x ? idx : idx + 1;
|
||||
}
|
||||
|
||||
void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) {
|
||||
|
||||
constexpr int kBlockSize = 16;
|
||||
|
||||
block_iq2_k * y = (block_iq2_k *)vy;
|
||||
|
||||
float scales[QK_K/kBlockSize];
|
||||
float weight[kBlockSize];
|
||||
float sumx[kBlockSize+1], sumw[kBlockSize+1];
|
||||
|
||||
std::array<std::pair<float,int>, kBlockSize> pairs;
|
||||
|
||||
const int8_t * shifted_values = iq2nl_values + 4;
|
||||
|
||||
for (int ibl = 0; ibl < n_per_row/QK_K; ++ibl) {
|
||||
|
||||
memset(&y[ibl], 0, sizeof(block_iq2_k));
|
||||
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
||||
|
||||
const float * xbl = x + ibl*QK_K;
|
||||
float sumx2 = 0;
|
||||
for (int j = 0; j < QK_K; ++j) sumx2 += xbl[j]*xbl[j];
|
||||
const float sigma2 = 1.5f*sumx2/QK_K;
|
||||
|
||||
uint16_t extra = 0;
|
||||
|
||||
float max_abs_scale = 0;
|
||||
|
||||
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
|
||||
const float * xb = xbl + kBlockSize*ib;
|
||||
if (quant_weights) {
|
||||
const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||
} else {
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
|
||||
}
|
||||
for (int j = 0; j < kBlockSize; ++j) pairs[j] = {xb[j], j};
|
||||
std::sort(pairs.begin(), pairs.end());
|
||||
sumx[0] = sumw[0] = 0;
|
||||
for (int j = 0; j < kBlockSize; ++j) {
|
||||
int jj = pairs[j].second;
|
||||
sumw[j+1] = sumw[j] + weight[jj];
|
||||
sumx[j+1] = sumx[j] + weight[jj]*xb[jj];
|
||||
}
|
||||
float best = 0, d = 0;
|
||||
bool is_shifted = false;
|
||||
float sumqx, sumq2;
|
||||
for (int i1 = 0; i1 < kBlockSize; ++i1) {
|
||||
for (int i2 = i1; i2 < kBlockSize; ++i2) {
|
||||
for (int i3 = i2; i3 < kBlockSize; ++i3) {
|
||||
sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[0] + (sumx[i2] - sumx[i1])*iq2nl_values[1]
|
||||
+ (sumx[i3] - sumx[i2])*iq2nl_values[2] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[3];
|
||||
sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[0]*iq2nl_values[0] + (sumw[i2] - sumw[i1])*iq2nl_values[1]*iq2nl_values[1]
|
||||
+ (sumw[i3] - sumw[i2])*iq2nl_values[2]*iq2nl_values[2] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[3]*iq2nl_values[3];
|
||||
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
||||
d = sumqx/sumq2; best = d*sumqx; is_shifted = false;
|
||||
}
|
||||
sumqx = (sumx[i1] - sumx[ 0])*shifted_values[0] + (sumx[i2] - sumx[i1])*shifted_values[1]
|
||||
+ (sumx[i3] - sumx[i2])*shifted_values[2] + (sumx[kBlockSize] - sumx[i3])*shifted_values[3];
|
||||
sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[0]*shifted_values[0] + (sumw[i2] - sumw[i1])*shifted_values[1]*shifted_values[1]
|
||||
+ (sumw[i3] - sumw[i2])*shifted_values[2]*shifted_values[2] + (sumw[kBlockSize] - sumw[i3])*shifted_values[3]*shifted_values[3];
|
||||
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
||||
d = sumqx/sumq2; best = d*sumqx; is_shifted = true;
|
||||
}
|
||||
sumqx = (sumx[i1] - sumx[ 0])*iq2nl_values[3] + (sumx[i2] - sumx[i1])*iq2nl_values[2]
|
||||
+ (sumx[i3] - sumx[i2])*iq2nl_values[1] + (sumx[kBlockSize] - sumx[i3])*iq2nl_values[0];
|
||||
sumq2 = (sumw[i1] - sumw[ 0])*iq2nl_values[3]*iq2nl_values[3] + (sumw[i2] - sumw[i1])*iq2nl_values[2]*iq2nl_values[2]
|
||||
+ (sumw[i3] - sumw[i2])*iq2nl_values[1]*iq2nl_values[1] + (sumw[kBlockSize] - sumw[i3])*iq2nl_values[0]*iq2nl_values[0];
|
||||
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
||||
d = sumqx/sumq2; best = d*sumqx; is_shifted = false;
|
||||
}
|
||||
sumqx = (sumx[i1] - sumx[ 0])*shifted_values[3] + (sumx[i2] - sumx[i1])*shifted_values[2]
|
||||
+ (sumx[i3] - sumx[i2])*shifted_values[1] + (sumx[kBlockSize] - sumx[i3])*shifted_values[0];
|
||||
sumq2 = (sumw[i1] - sumw[ 0])*shifted_values[3]*shifted_values[3] + (sumw[i2] - sumw[i1])*shifted_values[2]*shifted_values[2]
|
||||
+ (sumw[i3] - sumw[i2])*shifted_values[1]*shifted_values[1] + (sumw[kBlockSize] - sumw[i3])*shifted_values[0]*shifted_values[0];
|
||||
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
|
||||
d = sumqx/sumq2; best = d*sumqx; is_shifted = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
scales[ib] = d;
|
||||
if (is_shifted) extra |= (1 << ib);
|
||||
|
||||
float abs_scale = fabsf(scales[ib]);
|
||||
max_abs_scale = MAX(max_abs_scale, abs_scale);
|
||||
}
|
||||
|
||||
if (!max_abs_scale) continue;
|
||||
|
||||
float d = max_abs_scale/15;
|
||||
y[ibl].d = GGML_FP32_TO_FP16(d);
|
||||
y[ibl].extra = extra;
|
||||
float id = 1/d;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int ib = 0; ib < QK_K/kBlockSize; ++ib) {
|
||||
int ls = nearest_int(0.5f*(id*scales[ib]+15));
|
||||
ls = MAX(0, MIN(15, ls));
|
||||
y[ibl].scales[ib/2] |= (ls << 4*(ib%2));
|
||||
ls = 2*ls - 15;
|
||||
float dl = d * ls;
|
||||
if (dl) {
|
||||
const int8_t * block_values = y[ibl].extra & (1 << ib) ? shifted_values : iq2nl_values;
|
||||
const float * xb = xbl + kBlockSize*ib;
|
||||
if (quant_weights) {
|
||||
const float * qw = quant_weights + ibl*QK_K + ib*kBlockSize;
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||
} else {
|
||||
for (int j = 0; j < kBlockSize; ++j) weight[j] = 0.25f*sigma2 + xb[j]*xb[j];
|
||||
}
|
||||
float idl = 1/dl;
|
||||
int ib32 = ib/2;
|
||||
int offset = 16*(ib%2);
|
||||
uint8_t * qs = y[ibl].qs + 32*(ib32/4) + offset;
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
const float al = idl*xb[j];
|
||||
int ibest = best_index_iq2nl(block_values, al);
|
||||
qs[j] |= (ibest << 2*(ib32%4));
|
||||
float w = weight[j];
|
||||
float q = block_values[ibest]*ls;
|
||||
sumqx += w*q*xb[j];
|
||||
sumq2 += w*q*q;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(sumqx/sumq2);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
quantize_iq2_k(x, (void *)y, 1, k, nullptr);
|
||||
}
|
||||
|
||||
void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
block_iq2_k * y = (block_iq2_k *)vy;
|
||||
quantize_row_iq2_k_ref(x, y, k);
|
||||
}
|
||||
|
||||
size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
|
||||
GGML_ASSERT(n_per_row%QK_K == 0);
|
||||
int nblock = n_per_row/QK_K;
|
||||
char * qrow = (char *)dst;
|
||||
for (int64_t row = 0; row < nrows; ++row) {
|
||||
quantize_row_iq2_k_impl(src, (void *)qrow, n_per_row, imatrix);
|
||||
src += n_per_row;
|
||||
qrow += nblock*sizeof(block_iq2_k);
|
||||
}
|
||||
return nrows * nblock * sizeof(block_iq2_k);
|
||||
}
|
||||
|
||||
void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d);
|
||||
const uint8_t * qs = x[i].qs;
|
||||
|
||||
uint16_t extra = x[i].extra;
|
||||
|
||||
int shift = 0;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
||||
float dl1 = d * (2*(x[i].scales[ib32] & 0xf) - 15);
|
||||
float dl2 = d * (2*(x[i].scales[ib32] >> 4) - 15);
|
||||
const int8_t * values1 = extra & 1 ? iq2nl_values + 4 : iq2nl_values;
|
||||
const int8_t * values2 = extra & 2 ? iq2nl_values + 4 : iq2nl_values;
|
||||
extra >>= 2;
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
y[j+ 0] = dl1 * values1[(qs[j+ 0] >> shift) & 3];
|
||||
y[j+16] = dl2 * values2[(qs[j+16] >> shift) & 3];
|
||||
}
|
||||
y += 32;
|
||||
shift += 2;
|
||||
if (shift == 8) { qs += 32; shift = 0; }
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
GGML_UNUSED(nrc);
|
||||
GGML_UNUSED(bx);
|
||||
GGML_UNUSED(by);
|
||||
GGML_UNUSED(bs);
|
||||
|
||||
if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_K, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
const block_iq2_k * x = (const block_iq2_k *)vx;
|
||||
const block_q8_K * y = (const block_q8_K *)vy;
|
||||
}
|
||||
|
||||
@@ -19,6 +19,12 @@ size_t quantize_iq4_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
|
||||
void dequantize_row_iq4_k(const block_iq4_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq4_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
size_t quantize_iq2_k(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user