Be able to re-quantize MS BitNet I2_S models (#169)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-01-10 18:18:04 +02:00
committed by GitHub
parent c411615505
commit 400b774294
6 changed files with 58 additions and 0 deletions

View File

@@ -392,6 +392,10 @@ extern "C" {
GGML_TYPE_Q4_0_4_8 = 32,
GGML_TYPE_Q4_0_8_8 = 33,
//
// So we are able to consume MS BitNet I2_S quants
//
GGML_TYPE_I2_S = 36,
//
GGML_TYPE_Q6_0 = 133,
GGML_TYPE_IQ1_BN = 134,
GGML_TYPE_IQ2_BN = 135,

View File

@@ -15236,6 +15236,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_I2_S:
// nothing to validate
break;
default:

View File

@@ -1605,6 +1605,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.nrows = 1,
.row_meta_size = 0,
},
[GGML_TYPE_I2_S] = {
.type_name = "i2_s",
.blck_size = 1,
.type_size = 1,
.is_quantized = true,
.to_float = dequantize_row_ms_i2s,
.from_float = NULL,
.from_float_ref = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
.row_meta_size = 0,
},
};
// For internal test use
@@ -4130,6 +4143,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
// hack for I2_S
if(tensor->type == GGML_TYPE_I2_S) {
nbytes = nbytes / 4 + 32;
}
}
else {
nbytes = tensor->nb[1]; //tensor->ne[0]*tensor->nb[0]/blck_size;
@@ -10825,6 +10842,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -11290,6 +11308,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -11452,6 +11471,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -14660,6 +14680,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -15062,6 +15083,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -15358,6 +15380,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:
@@ -15983,6 +16006,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_0_R4:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_I2_S:
case GGML_TYPE_Q8_0_R4:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_KS:

View File

@@ -5874,3 +5874,23 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) {
tensor->type = r.new_type;
}
void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) {
constexpr int kBlockSize = 128;
constexpr int kGroupSize = kBlockSize/4;
GGML_ASSERT(k % kBlockSize == 0);
const uint8_t * x = (const uint8_t *)vx;
const float * dptr = (const float *)(x + k/4);
const float d = dptr[0];
int nb = k/kBlockSize;
for (int ib = 0; ib < nb; ++ib) {
for (int ig = 0; ig < kBlockSize/kGroupSize; ++ig) {
int shift = 6 - 2*ig;
for (int j = 0; j < kGroupSize; ++j) {
y[j] = d * (((x[j] >> shift) & 3) - 1);
}
y += kGroupSize;
}
x += kGroupSize;
}
}

View File

@@ -217,6 +217,9 @@ void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT d
void iqk_repack_tensor(struct ggml_tensor * tensor);
// So we can re-pack Microsoft's BitNet I2_S quants
void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
#ifdef __cplusplus
}
#endif