From 69689d60e5f547b190c5d84162cca9541f057bb2 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 5 Dec 2024 06:53:41 +0200 Subject: [PATCH] Not working bf16_r4 --- examples/quantize/quantize.cpp | 1 + ggml/include/ggml.h | 2 + ggml/src/ggml-quants.c | 1 + ggml/src/ggml.c | 20 ++++ ggml/src/iqk/iqk_mul_mat.cpp | 170 ++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_quantize.cpp | 56 +++++++++++ ggml/src/iqk/iqk_quantize.h | 3 + include/llama.h | 1 + src/llama.cpp | 13 +++ 9 files changed, 266 insertions(+), 1 deletion(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 0f906b83..4401eb1d 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -77,6 +77,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", }, { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, + { "BF16_R4", LLAMA_FTYPE_MOSTLY_BF16_R4, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching. 
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f9ff97a7..febb8960 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -420,6 +420,7 @@ extern "C" { GGML_TYPE_Q6_K_R4 = 214, GGML_TYPE_IQ4_NL_R4 = 220, GGML_TYPE_IQ4_XS_R4 = 223, + GGML_TYPE_BF16_R4 = 230, GGML_TYPE_Q6_0_R4 = 233, GGML_TYPE_IQ2_BN_R4 = 335, GGML_TYPE_IQ4_K_R4 = 339, @@ -493,6 +494,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q6_K_R4 = 214, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors + GGML_FTYPE_MOSTLY_BF16_R4 = 224, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index f12c9fe8..d76d41d9 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15209,6 +15209,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_Q6_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_BF16_R4: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 772c70c4..1bdc7e92 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1231,6 +1231,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_BF16_R4] = { + .type_name = "bf16_r4", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + //.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + //.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + //.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, + //.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q4_0_4_4] = { 
.type_name = "q4_0_4x4", .blck_size = QK4_0, @@ -4110,6 +4123,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; + case GGML_FTYPE_MOSTLY_BF16_R4: wtype = GGML_TYPE_BF16_R4;break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; @@ -15748,6 +15762,7 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_BF16_R4: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -22651,6 +22666,11 @@ size_t ggml_quantize_chunk( ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n); result = n * elemsize; } break; + case GGML_TYPE_BF16_R4: + { + repack_f32_bf16_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row); + result = nrows * row_size; + } break; case GGML_TYPE_F32: { size_t elemsize = sizeof(float); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 75e5c3c1..1698e652 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -5512,7 +5512,8 @@ struct QFBaseBF16 { using Data = __m512bh; using Acc = __m512; static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); } - static inline Acc acc(Acc prev, const Data& y, const Data& x) { + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + static inline Acc acc(Acc prev, Data y, Data x) { return _mm512_dpbf16_ps(prev, y, x); } static inline Acc acc_first(const Data& y, const Data& x) { @@ -5563,6 +5564,150 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, } for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, 
QFBaseBF16::hsum(acc[nrc_x*iy+ix])); } + +template <int nrc_y> +void mul_mat_Qx_Qy_1xN_r4(int n, const char * cx, size_t bx, const DataInfo& info, int nrc_x) { + int nb = n/QFBaseBF16::k_step; + //printf("%s: n = %d, nrc_x = %d, bx = %zu nb = %d\n", __func__, n, nrc_x, bx, nb); + QFTBF16<nrc_y> y(info); + QFBaseBF16::Acc acc[nrc_y] = {}; + QFBaseBF16::Data xv[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + //printf("Working on %d out of %d\n", ix, nrc_x); + QFTBF16<1> x(cx + ix*bx, bx); + for (int i = 0; i < nb; ++i) { + //printf(" block %d out of %d\n", i, nb); + for (int k = 0; k < 4; ++k) xv[k] = x.load1(0, 4*i+k); + for (int iy = 0; iy < nrc_y; ++iy) { + //printf(" iy = %d\n", iy); + auto vy = y.load1(iy, i); + acc[iy] = QFBaseBF16::acc(acc[iy], xv[0], __m512bh(_mm512_shuffle_epi32(__m512i(vy), _MM_PERM_ENUM(0x00)))); + acc[iy] = QFBaseBF16::acc(acc[iy], xv[1], __m512bh(_mm512_shuffle_epi32(__m512i(vy), _MM_PERM_ENUM(0x55)))); + acc[iy] = QFBaseBF16::acc(acc[iy], xv[2], __m512bh(_mm512_shuffle_epi32(__m512i(vy), _MM_PERM_ENUM(0xaa)))); + acc[iy] = QFBaseBF16::acc(acc[iy], xv[3], __m512bh(_mm512_shuffle_epi32(__m512i(vy), _MM_PERM_ENUM(0xff)))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(acc[iy]), _mm512_extractf32x8_ps(acc[iy], 1)); + auto sum128 = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); + info.store(ix, iy, sum128); + acc[iy] = _mm512_setzero_ps(); + } + } +} + +template <int nrc_y> +void mul_mat_Qx_Qy_2xN_r4(int n, const char * cx, size_t bx, const DataInfo& info, int nrc_x) { + int nb = n/QFBaseBF16::k_step; + //printf("%s: n = %d, nrc_x = %d, nrc_y = %d, bx = %zu nb = %d\n", __func__, n, nrc_x, nrc_y, bx, nb); + QFTBF16<nrc_y> y(info); + QFBaseBF16::Acc acc[2*nrc_y]; + QFBaseBF16::Data xv[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + QFTBF16<1> x1(cx + (ix+0)*bx, bx); + QFTBF16<1> x2(cx + (ix+4)*bx, bx); + for (int iy = 0; iy < 2*nrc_y; ++iy) acc[iy] = _mm512_setzero_ps(); + for (int i = 0; i < nb; 
++i) { + for (int k = 0; k < 4; ++k) { xv[2*k+0] = x1.load1(0, 4*i+k); xv[2*k+1] = x2.load1(0, 4*i+k); } + for (int iy = 0; iy < nrc_y; ++iy) { + auto vy = y.load1(iy, i); + __m512bh sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0x00)); + acc[2*iy+0] = QFBaseBF16::acc(acc[2*iy+0], xv[0], sy); + acc[2*iy+1] = QFBaseBF16::acc(acc[2*iy+1], xv[1], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0x55)); + acc[2*iy+0] = QFBaseBF16::acc(acc[2*iy+0], xv[2], sy); + acc[2*iy+1] = QFBaseBF16::acc(acc[2*iy+1], xv[3], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0xaa)); + acc[2*iy+0] = QFBaseBF16::acc(acc[2*iy+0], xv[4], sy); + acc[2*iy+1] = QFBaseBF16::acc(acc[2*iy+1], xv[5], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0xff)); + acc[2*iy+0] = QFBaseBF16::acc(acc[2*iy+0], xv[6], sy); + acc[2*iy+1] = QFBaseBF16::acc(acc[2*iy+1], xv[7], sy); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(acc[2*iy+0]), _mm512_extractf32x8_ps(acc[2*iy+0], 1)); + auto sum128 = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); + info.store(ix, iy, sum128); + sum256 = _mm256_add_ps(_mm512_castps512_ps256(acc[2*iy+1]), _mm512_extractf32x8_ps(acc[2*iy+1], 1)); + sum128 = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); + info.store(ix+4, iy, sum128); + } + } +} + +template <int nrc_y> +void mul_mat_Qx_Qy_4xN_r4(int n, const char * cx, size_t bx, const DataInfo& info, int nrc_x) { + int nb = n/QFBaseBF16::k_step; + //printf("%s: n = %d, nrc_x = %d, nrc_y = %d, bx = %zu nb = %d\n", __func__, n, nrc_x, nrc_y, bx, nb); + QFTBF16<nrc_y> y(info); + QFBaseBF16::Acc acc[4*nrc_y] = {}; + QFBaseBF16::Data xv[16]; + for (int ix = 0; ix < nrc_x; ix += 16) { + QFTBF16<1> x1(cx + (ix+0)*bx, bx); + QFTBF16<1> x2(cx + (ix+4)*bx, bx); + QFTBF16<1> x3(cx + (ix+8)*bx, bx); + QFTBF16<1> x4(cx + (ix+12)*bx, bx); + for (int i = 0; i < nb; ++i) { + 
for (int k = 0; k < 4; ++k) { + xv[4*k+0] = x1.load1(0, 4*i+k); + xv[4*k+1] = x2.load1(0, 4*i+k); + xv[4*k+2] = x3.load1(0, 4*i+k); + xv[4*k+3] = x4.load1(0, 4*i+k); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto vy = y.load1(iy, i); + __m512bh sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0x00)); + acc[4*iy+0] = QFBaseBF16::acc(acc[4*iy+0], xv[0], sy); + acc[4*iy+1] = QFBaseBF16::acc(acc[4*iy+1], xv[1], sy); + acc[4*iy+2] = QFBaseBF16::acc(acc[4*iy+2], xv[2], sy); + acc[4*iy+3] = QFBaseBF16::acc(acc[4*iy+3], xv[3], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0x55)); + acc[4*iy+0] = QFBaseBF16::acc(acc[4*iy+0], xv[4], sy); + acc[4*iy+1] = QFBaseBF16::acc(acc[4*iy+1], xv[5], sy); + acc[4*iy+2] = QFBaseBF16::acc(acc[4*iy+2], xv[6], sy); + acc[4*iy+3] = QFBaseBF16::acc(acc[4*iy+3], xv[7], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0xaa)); + acc[4*iy+0] = QFBaseBF16::acc(acc[4*iy+0], xv[8], sy); + acc[4*iy+1] = QFBaseBF16::acc(acc[4*iy+1], xv[9], sy); + acc[4*iy+2] = QFBaseBF16::acc(acc[4*iy+2], xv[10], sy); + acc[4*iy+3] = QFBaseBF16::acc(acc[4*iy+3], xv[11], sy); + sy = (__m512bh)_mm512_shuffle_epi32((__m512i)vy, _MM_PERM_ENUM(0xff)); + acc[4*iy+0] = QFBaseBF16::acc(acc[4*iy+0], xv[12], sy); + acc[4*iy+1] = QFBaseBF16::acc(acc[4*iy+1], xv[13], sy); + acc[4*iy+2] = QFBaseBF16::acc(acc[4*iy+2], xv[14], sy); + acc[4*iy+3] = QFBaseBF16::acc(acc[4*iy+3], xv[15], sy); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int k = 0; k < 4; ++k) { + auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(acc[4*iy+k]), _mm512_extractf32x8_ps(acc[4*iy+k], 1)); + auto sum128 = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); + info.store(ix+4*k, iy, sum128); + acc[4*iy+k] = _mm512_setzero_ps(); + } + } + } +} + +template <int nrc_y> +void mul_mat_fX_fY_r4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%16 == 0); + const char * cx = (const char *)vx; + 
mul_mat_Qx_Qy_4xN_r4<nrc_y>(n, cx, bx, info, 16*(nrc_x/16)); + return; + //mul_mat_Qx_Qy_1xN_r4<nrc_y>(n, cx, bx, info, nrc_x); + //return; + if (nrc_x/8 > 0) { + mul_mat_Qx_Qy_2xN_r4<nrc_y>(n, cx, bx, info, 8*(nrc_x/8)); + cx += 8*(nrc_x/8)*bx; + nrc_x -= 8*(nrc_x/8); + } + if (nrc_x/4 > 0) { + mul_mat_Qx_Qy_1xN_r4<nrc_y>(n, cx, bx, info, nrc_x); + } +} + template <int nrc_y> void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { constexpr int k_nx = nrc_y <= 2 ? 8 : 5; @@ -5777,6 +5922,17 @@ void set_mul_mat_bf16(MulMat& mm) { mm.funcs[3] = mul_mat_fX_fY_T<4>; mm.funcs[4] = mul_mat_fX_fY_T<5>; } +void set_mul_mat_bf16_r4(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_fX_fY_r4<1>; + mm.funcs[1] = mul_mat_fX_fY_r4<2>; + mm.funcs[2] = mul_mat_fX_fY_r4<3>; + mm.funcs[3] = mul_mat_fX_fY_r4<4>; + //mm.funcs[4] = mul_mat_fX_fY_r4<5>; + //mm.funcs[5] = mul_mat_fX_fY_r4<6>; + //mm.funcs[6] = mul_mat_fX_fY_r4<7>; + //mm.funcs[7] = mul_mat_fX_fY_r4<8>; +} #endif bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { @@ -5794,6 +5950,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { return true; } + if (typeA == GGML_TYPE_BF16_R4) { + //printf("%s: %s\n", __func__, ggml_type_name((ggml_type)typeB)); + if (ne00 % 32) return false; + switch (typeB) { +#ifdef __AVX512BF16__ + case GGML_TYPE_BF16: set_mul_mat_bf16_r4(mm); break; +#endif + default: return false; + } + return true; + } + if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) { if (ne00 % 4) return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index de8c0d99..ada40703 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -4759,3 +4759,59 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } +// +// ========================================= bf16_r4 +// +namespace { +template <typename Src> +void repack_bf16(const Src& src, ggml_bf16_t * dst) 
{ + GGML_ASSERT(src.nrows%4 == 0); + for (int row = 0; row < src.nrows; row += 4) { + auto y = dst + 4*row*src.n_per_row; + for (int j = 0; j < src.n_per_row; j += 32) { + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 4; ++k) for (int i = 0; i < 2; ++i) { + y[32*l+2*k+i+ 0] = src.value(row+k, j+2*l+i+ 0); + y[32*l+2*k+i+ 8] = src.value(row+k, j+2*l+i+ 8); + y[32*l+2*k+i+16] = src.value(row+k, j+2*l+i+16); + y[32*l+2*k+i+24] = src.value(row+k, j+2*l+i+24); + } + } + y += 128; + } + } +} +struct F32toBF16 { + F32toBF16(const void * src, int64_t nrows, int64_t n_per_row) : nrows(nrows), n_per_row(n_per_row), x((const float *)src) {} + inline ggml_bf16_t value(int row, int j) const { + union { float f; uint32_t u; } helper_32; + union { ggml_bf16_t f; uint16_t u; } helper_16; + helper_32.f = x[row*n_per_row + j]; + helper_16.u = helper_32.u >> 16; + return helper_16.f; + } + int64_t nrows; + int64_t n_per_row; +private: + const float * x; +}; +struct BF16 { + BF16(const void * src, int64_t nrows, int64_t n_per_row) : nrows(nrows), n_per_row(n_per_row), x((const ggml_bf16_t *)src) {} + inline ggml_bf16_t value(int row, int j) const { return x[row*n_per_row + j]; } + int64_t nrows; + int64_t n_per_row; +private: + const ggml_bf16_t * x; +}; +} + +void repack_f32_bf16_r4 (const void * src, void * dst, int64_t nrows, int64_t n_per_row) { + F32toBF16 helper(src, nrows, n_per_row); + repack_bf16(helper, (ggml_bf16_t *)dst); +} + +void repack_bf16_bf16_r4(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) { + BF16 helper(src, nrows, n_per_row); + repack_bf16(helper, (ggml_bf16_t *)dst); +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 753bbdb5..10754f21 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -158,6 +158,9 @@ void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * 
GGML_RESTRICT y, int64_t k); void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void repack_f32_bf16_r4 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); +void repack_bf16_bf16_r4(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index e4d6ed3d..10ed9cc5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -191,6 +191,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 335, // except 1d tensors + LLAMA_FTYPE_MOSTLY_BF16_R4 = 232, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 = 337, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 035e5b1a..a9ac7b6e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3828,6 +3828,7 @@ struct llama_model_loader { case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_BF16_R4: ftype = LLAMA_FTYPE_MOSTLY_BF16_R4; break; case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; @@ -4540,6 +4541,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_BF16_R4: return "BF16_R4"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -15833,6 +15835,9 @@ static ggml_type 
llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q8_0_R4) { new_type = GGML_TYPE_Q8_0; } + else if (new_type == GGML_TYPE_BF16_R4) { + new_type = GGML_TYPE_BF16; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || @@ -16228,6 +16233,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; + case LLAMA_FTYPE_MOSTLY_BF16_R4: default_type = GGML_TYPE_BF16_R4; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; // K-quants @@ -16520,6 +16526,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (quantize) { new_type = default_type; + if (new_type == GGML_TYPE_BF16_R4 && strcmp(tensor->name, "token_embd.weight") == 0) { + new_type = GGML_TYPE_BF16; + } // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { @@ -16680,6 +16689,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K; else chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_BF16_R4) { + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_BF16; + else chunk_size_multiplier = 8; + } LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout);