mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-21 15:09:40 +00:00
iq1_kt: CUDA dequantize
Testing with LlaMA-3.1-8B-Instruct, we get almost the same PPL as iq2_xxs, so about 0.2 bpw fewer bits for the same quality.
This commit is contained in:
@@ -75,6 +75,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
|
||||
{ "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
|
||||
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
|
||||
{ "IQ1_KT", LLAMA_FTYPE_MOSTLY_IQ1_KT, " 1.75 bpw trellis quantization", },
|
||||
{ "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", },
|
||||
{ "IQ2_KL", LLAMA_FTYPE_MOSTLY_IQ2_KL, " 2.69 bpw non-linear quantization", },
|
||||
{ "IQ3_KS", LLAMA_FTYPE_MOSTLY_IQ3_KS, " 3.19 bpw non-linear quantization", },
|
||||
|
||||
@@ -3506,6 +3506,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ2_K:
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
case GGML_TYPE_IQ1_KT:
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
|
||||
@@ -358,6 +358,26 @@ float __device__ __forceinline__ trellis_next(uint32_t& val) {
|
||||
return (float)(h[0]+h[1]);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq1_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
|
||||
|
||||
int64_t ii = blockIdx.x;
|
||||
int64_t row = (QK_K * ii) / n_per_row;
|
||||
const char * cx = (const char *)vx + row * row_size;
|
||||
float scale = *(const float *)cx;
|
||||
const block_iq1_kt * x = (const block_iq1_kt *)(cx + sizeof(float));
|
||||
const int64_t i = ii - (row*n_per_row)/QK_K;
|
||||
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ib = tid; // 0...31
|
||||
dst_t * y = yy + ii*QK_K + 8*ib;
|
||||
uint32_t idx = (x[i].ql[ib] | ((x[i].qh[ib%16] << (8 - 4*(ib/16))) & 0xf00) | ((x[i].sh[ib/4] << (8 - (ib%4))) & 0x1000)) + 4096;
|
||||
const float dl = scale * iq4k_values[x[i].sh[ib/4] & 0xf];
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
y[j] = dl * trellis_next_int(idx);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
|
||||
|
||||
@@ -1505,6 +1525,13 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
|
||||
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq1_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
|
||||
const int64_t k = nrows * n_per_row;
|
||||
const int nb = k / QK_K;
|
||||
dequantize_block_iq1_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ1_KT, n_per_row));
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
|
||||
const int64_t k = nrows * n_per_row;
|
||||
@@ -1888,6 +1915,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_KT:
|
||||
return dequantize_row_iq1_kt_cuda;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
return dequantize_row_iq2_kt_cuda;
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
@@ -1987,6 +2016,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
return dequantize_row_iq2_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_KT:
|
||||
return dequantize_row_iq1_kt_cuda;
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
return dequantize_row_iq2_kt_cuda;
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
|
||||
@@ -8615,10 +8615,25 @@ void quantize_row_iq1_kt_impl(const float * x, void * vy, int n_per_row, const f
|
||||
auto idx = best_idx;
|
||||
if (score_p > score_m) scales[ib] = dp;
|
||||
else {
|
||||
scales[ib] = dm; idx += Q::kNg;
|
||||
scales[ib] = dm; idx += Q::kNg; score_p = score_m;
|
||||
}
|
||||
for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig];
|
||||
|
||||
scale_0 -= 8;
|
||||
quantizer.find_best_match( amax/scale_0, xb, weight, best_idx);
|
||||
auto [dp1, score_p1] = quantizer.find_best_scale(xb, weight, best_idx);
|
||||
quantizer.find_best_match(-amax/scale_0, xb, weight, best_idx + Q::kNg);
|
||||
auto [dm1, score_m1] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg);
|
||||
|
||||
if (score_p1 > score_p || score_m1 > score_p) {
|
||||
idx = best_idx;
|
||||
if (score_p1 > score_m1) scales[ib] = dp1;
|
||||
else {
|
||||
scales[ib] = dm1; idx += Q::kNg;
|
||||
}
|
||||
for (int ig = 0; ig < Q::kNg; ++ig) all_idx[(ibl*Q::kSuperBlockSize + ib*Q::kBlockSize)/Q::kGroupSize + ig] = idx[ig];
|
||||
}
|
||||
|
||||
float abs_scale = std::abs(scales[ib]);
|
||||
if (abs_scale > amax_scale) {
|
||||
amax_scale = abs_scale;
|
||||
@@ -8726,7 +8741,7 @@ void quantize_row_iq1_kt_impl(const float * x, void * vy, int n_per_row, const f
|
||||
}
|
||||
if (sumq2 > 0) {
|
||||
d = sumqx/sumq2;
|
||||
*dptr = d;
|
||||
*dptr = d * 1.07f;
|
||||
if (!d) return;
|
||||
} else {
|
||||
break;
|
||||
|
||||
@@ -206,6 +206,7 @@ extern "C" {
|
||||
LLAMA_FTYPE_MOSTLY_IQ4_KT = 153, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ3_KS = 154, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ2_KL = 155, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_IQ1_KT = 156, // except 1d tensors
|
||||
//
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
|
||||
|
||||
@@ -4411,6 +4411,7 @@ struct llama_model_loader {
|
||||
case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break;
|
||||
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
|
||||
case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break;
|
||||
case GGML_TYPE_IQ1_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT; break;
|
||||
case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break;
|
||||
case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break;
|
||||
case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break;
|
||||
@@ -5156,6 +5157,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M_R4: return "IQ2_M_R4 - 2.7 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_KT: return "IQ1_KT - 1.75 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_KT: return "IQ2_KT - 2.125 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_KT: return "IQ3_KT - 3.125 bpw";
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_KT: return "IQ4_KT - 4.0 bpw";
|
||||
@@ -19152,7 +19154,8 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) {
|
||||
new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 ||
|
||||
new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4|| new_type == GGML_TYPE_IQ3_KS ||
|
||||
new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT ||
|
||||
new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4|| new_type == GGML_TYPE_IQ2_KL) {
|
||||
new_type == GGML_TYPE_IQ5_KS || new_type == GGML_TYPE_IQ5_KS_R4|| new_type == GGML_TYPE_IQ2_KL ||
|
||||
new_type == GGML_TYPE_IQ1_KT) {
|
||||
if (nx % QK_K != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
@@ -19192,6 +19195,7 @@ static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) {
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
case GGML_TYPE_IQ4_XS_R8:
|
||||
case GGML_TYPE_IQ1_KT:
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
@@ -19324,7 +19328,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KL ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4 ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) {
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ1_KT) {
|
||||
new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) {
|
||||
@@ -19918,6 +19922,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS_R4:default_type = GGML_TYPE_IQ2_XS_R4; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_KS: default_type = GGML_TYPE_IQ2_KS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_KT: default_type = GGML_TYPE_IQ1_KT; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_KT: default_type = GGML_TYPE_IQ2_KT; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
|
||||
|
||||
Reference in New Issue
Block a user