Bitnet(1.75 bpw): higher precision fp8 scale

Use 3 bits for the exponent and 5 bits for the mantissa.
This makes PPL the same as for fp16 (although the previous
version, with 4 bits each for the exponent and the mantissa,
was already good enough for any practical purpose).
This commit is contained in:
Kawrakow
2024-06-18 20:08:28 +03:00
parent 9d38a61be7
commit 1f9541172f
7 changed files with 81 additions and 84 deletions

View File

@@ -501,6 +501,12 @@ static __device__ __forceinline__ float get_alibi_slope(
return powf(base, exph);
}
// Decodes the 8-bit IQ1_BN scale byte into a float.
//
// Layout of fp8: bits 7..5 are a 3-bit exponent field e, bits 4..0 are a 5-bit
// mantissa field m; there is no sign bit. The exponent field is re-biased by 116
// into the binary32 exponent position (bit 23) and the mantissa is placed in the
// top five fraction bits (bit 18), so the decoded value is
//     2^(e - 11) * (1 + m/32)
// i.e. 0x00 -> 2^-11 and 0xff -> 0.123046875.
//
// The bit pattern is reinterpreted with __uint_as_float rather than read back
// through a union member, which avoids relying on union type punning in C++.
static __device__ __forceinline__ float iq1bn_fp8_to_float(uint8_t fp8) {
    const uint32_t exponent = (uint32_t)(fp8 >> 5) + 116u;  // biased binary32 exponent, 116..123
    const uint32_t mantissa = (uint32_t)(fp8 & 0x1f);       // top 5 bits of the fraction
    return __uint_as_float((exponent << 23) | (mantissa << 18));
}
template <ggml_type type>
struct ggml_cuda_type_traits;

View File

@@ -432,11 +432,8 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
if (i >= nb64) return;
ib = ib%(QK_IQ1BN/32);
typedef union { float f; uint32_t i; } scale_t;
scale_t s;
uint8_t u = x[i].extra & 0xff;
s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? -s.f : s.f;
float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? -d : d;
const float ml = -dl;
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
const uint16_t gp = iq1bn_grid_u16[idx];

View File

@@ -1078,10 +1078,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
typedef union { float f; uint32_t i; } scale_t;
scale_t s;
uint8_t u = bq1->extra & 0xff;
s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d = iq1bn_fp8_to_float(bq1->extra & 0xff);
uint8_t extra = bq1->extra >> (8 + 4*iqs);
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
@@ -1110,7 +1107,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
q8 += 8;
}
#endif
return s.f * __low2float(bq8_1[iqs].ds) * sumi;
return d * __low2float(bq8_1[iqs].ds) * sumi;
}
// TODO

View File

@@ -4992,6 +4992,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
}
}
// Expands the packed 8-bit IQ1_BN scale into a float.
// The upper 3 bits of fp8 are an exponent field (offset by 116 to form the
// binary32 biased exponent) and the lower 5 bits become the five most
// significant fraction bits; no sign bit is stored.
static inline float iq1bn_fp8_to_float(uint8_t fp8) {
    union { float f; uint32_t i; } bits;
    const uint32_t exp_field = (fp8 >> 5) + 116;
    const uint32_t man_field = fp8 & 0x1f;
    bits.i = (exp_field << 23) | (man_field << 18);
    return bits.f;
}
void kernel_mul_mv_iq1_bn_f32_impl(
device const void * src0,
device const float * src1,
@@ -5036,13 +5042,8 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const float * y4 = y + 32 * ix + 8 * ir;
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
for (int row = 0; row < N_DST; ++row) {
uint8_t u = x[nb*row].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
d1bn[row] = scale.f;
d1bn[row] = iq1bn_fp8_to_float(x[nb*row].extra & 0xff);
}
uint32_t aux32[2];
@@ -5138,9 +5139,6 @@ void kernel_mul_mv_iq2_bn_f32_impl(
device const float * y4 = y + 64 * ix + 4 * ir;
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
for (int row = 0; row < N_DST; ++row) {
d1bn[row] = x[nb*row].d;
}
@@ -5945,15 +5943,10 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
uint8_t u = xb->extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
//uint32_t u = xb->extra & 0xff;
//scale.i = (u << 19) + 905969664;
const float d = iq1bn_fp8_to_float(xb->extra & 0xff);
uint8_t gs = xb->extra >> (8 + 2*il);
const float d1 = gs & 1 ? -scale.f : scale.f;
const float d2 = gs & 2 ? -scale.f : scale.f;
const float d1 = gs & 1 ? -d : d;
const float d2 = gs & 2 ? -d : d;
uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
@@ -5969,19 +5962,6 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4
reg[2][i] = d2*aux8[2] - d2;
reg[3][i] = d2*aux8[3] - d2;
}
//Basically same performance as above. I guess, the compiler makes the transformation automatically
//uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
//uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
//for (int i = 0; i < 4; ++i) {
// reg[0][i] = d1*((v1 >> 2*i) & 3) - d1;
// reg[2][i] = d2*((v2 >> 2*i) & 3) - d2;
//}
//v1 >>= 8; v2 >>= 8;
//for (int i = 0; i < 4; ++i) {
// reg[1][i] = d1*((v1 >> 2*i) & 3) - d1;
// reg[3][i] = d2*((v2 >> 2*i) & 3) - d2;
//}
}
template <typename type4x4>

View File

@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iqk-quantize.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#define GGML_COMMON_IMPL_C
@@ -81,10 +82,6 @@ IQ1BNData::IQ1BNData() {
}
struct IQ1BNQuantizer {
typedef union {
float f;
uint32_t i;
} scale_t;
constexpr static int block_size = QK_IQ1BN;
int8_t L[QK_IQ1BN];
void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
@@ -128,22 +125,11 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i
auto max_in_row = row_max(n_per_row, src);
max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation
scale_t s;
uint8_t u = 0;
if (max_in_row > 1.9074e-06f && max_in_row < 0.12109f) {
s.f = max_in_row;
u = ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf);
s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
} else {
// outside the allowed range. Small values we can handle via quants set to zero, so we only warn about too large values
if (max_in_row >= 0.12109f) {
u = 255;
fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row);
} else{
u = 0;
}
max_in_row *= 1.015625f; // i.e., round to nearest in our fp8 representation
if (max_in_row > iq1bn_max_value()) {
fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row);
}
auto u = iq1bn_float_to_fp8(max_in_row);
for (int ib = 0; ib < nblock; ++ib) {
std::memset(&y[ib], 0, sizeof(block_iq1_bn));
@@ -205,12 +191,8 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
assert(k%QK_IQ1BN == 0);
int nblock = k / QK_IQ1BN;
IQ1BNQuantizer::scale_t s;
for (int i = 0; i < nblock; ++i) {
uint16_t u = x[i].extra & 0xff;
s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d = s.f;
float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
uint8_t extra = x[i].extra >> 8;
auto qh = x[i].qh;
auto ql = x[i].ql;
@@ -276,11 +258,9 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
int nblock = n / QK_IQ1BN;
float sumf = 0;
IQ1BNQuantizer::scale_t scale;
for (int i = 0; i < nblock; ++i) {
uint16_t u = x[i].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
uint8_t extra = x[i].extra >> 8;
auto qh = x[i].qh;
auto ql = x[i].ql;
@@ -304,7 +284,7 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz
sumi2 += extra & (1 << k) ? -sl : sl;
q8 += 8;
}
sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
sumf += d * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2);
}
*s = sumf;
@@ -325,10 +305,8 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
int nblock = n / QK_IQ1BN;
float sumf = 0;
IQ1BNQuantizer::scale_t scale;
uint16_t u = x[0].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int i = 0; i < nblock; ++i) {
uint8_t extra = x[i].extra >> 8;
auto qh = x[i].qh;
@@ -351,7 +329,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
sumi += extra & (1 << k) ? -sl : sl;
q8 += 8;
}
sumf += scale.f * (y[i].d) * sumi;
sumf += d * (y[i].d) * sumi;
}
*s = sumf;

46
iqk-quantize.h Normal file
View File

@@ -0,0 +1,46 @@
#pragma once
#include <stdint.h>
// Overlay of a float and its raw 32-bit pattern. The fp8 <-> float helpers build
// or inspect the bit pattern through `i` and read or write the value through `f`;
// this assumes float is a 32-bit IEEE-754 binary32.
typedef union { float f; uint32_t i; } iq1bn_scale_t;
#ifdef __cplusplus
extern "C" {
#endif
// Smallest and largest scale values the 8-bit encoding is meant to cover.
// iq1bn_float_to_fp8() maps anything <= min to code 0 and anything >= max to
// code 255.
#ifdef BITNET_IQ1BN_4x4
// 4-bit exponent / 4-bit mantissa layout.
// The minimum is slightly above 2^-19, the value encoded by code 0.
static inline float iq1bn_min_value(void) { return 1.9074e-06f; }
static inline float iq1bn_max_value(void) { return 0.12109f; }
#else
// 3-bit exponent / 5-bit mantissa layout.
// The minimum must be exactly 2^-11 (biased exponent 116, the value of code 0).
// A rounded-down decimal such as 0.000488281f lies a few ulps below 2^-11, and
// floats in that gap have biased exponent 115: they would slip past the
// `f <= min` guard in iq1bn_float_to_fp8(), where `exponent - 116` wraps around
// and the value is encoded as 255 (~0.123) instead of 0.
static inline float iq1bn_min_value(void) { return 0.00048828125f; }
static inline float iq1bn_max_value(void) { return 0.123047f; }
#endif
// Encodes a positive float scale as the 8-bit IQ1_BN scale code.
//
// Values <= iq1bn_min_value() (including zero and negatives) map to 0, values
// >= iq1bn_max_value() map to 255. In between, the biased binary32 exponent is
// re-based to the smallest representable exponent and the leading fraction bits
// are kept; the remaining fraction bits are truncated, not rounded.
//
// The biased exponent is masked and clamped before it is re-based. Without the
// clamp, a float that passes the min/max guards but whose exponent lies just
// outside the code range (e.g. when the min constant is a rounded decimal a few
// ulps below the true power of two) would make `exponent - base` wrap around and
// encode a tiny value as a near-maximal one. NaN fails both comparisons and has
// exponent field 0xff, so it ends up at 255 via the clamp.
static inline uint8_t iq1bn_float_to_fp8(float f) {
    if (f <= iq1bn_min_value()) return 0;
    if (f >= iq1bn_max_value()) return 255;
    iq1bn_scale_t s;
    s.f = f;
    const uint32_t biased_exp = (s.i >> 23) & 0xff;  // drop the sign bit
#ifdef BITNET_IQ1BN_4x4
    // 4-bit exponent (biased exponents 108..123) + 4-bit mantissa
    if (biased_exp < 108) return 0;
    if (biased_exp > 123) return 255;
    return (uint8_t)(((biased_exp - 108) << 4) | ((s.i >> 19) & 0xf));
#else
    // 3-bit exponent (biased exponents 116..123) + 5-bit mantissa
    if (biased_exp < 116) return 0;
    if (biased_exp > 123) return 255;
    return (uint8_t)(((biased_exp - 116) << 5) | ((s.i >> 18) & 0x1f));
#endif
}
// Decodes the 8-bit IQ1_BN scale code back into a float by assembling the
// binary32 bit pattern directly: the exponent field is shifted into the biased
// exponent position and the mantissa field into the top fraction bits.
//   BITNET_IQ1BN_4x4: 4-bit exponent (biased 108..123), 4-bit mantissa
//   otherwise:        3-bit exponent (biased 116..123), 5-bit mantissa
static inline float iq1bn_fp8_to_float(uint8_t fp8) {
#ifdef BITNET_IQ1BN_4x4
    const uint32_t exp_bits = ((fp8 >> 4) | 0xf0) - 132;
    const uint32_t man_bits = (fp8 & 0x0f) << 19;
#else
    const uint32_t exp_bits = (fp8 >> 5) + 116;
    const uint32_t man_bits = (fp8 & 0x1f) << 18;
#endif
    iq1bn_scale_t bits;
    bits.i = (exp_bits << 23) | man_bits;
    return bits.f;
}
#ifdef __cplusplus
}
#endif

View File

@@ -31,6 +31,7 @@
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "iqk_mul_mat.h"
#include "iqk-quantize.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -1344,15 +1345,11 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
//auto step = bx / sizeof(block_iq1_bn);
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
uint16_t u = x[0].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
@@ -1401,7 +1398,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, scale.f * hsum_float_8(accd[iy]));
info.store(ix, iy, d1 * hsum_float_8(accd[iy]));
}
}
@@ -4128,15 +4125,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
for (int ix = 0; ix < nrc_x; ++ix) {
x = (const block_iq1_bn *)((const char *)vx + ix*bx);
uint16_t u = x[0].extra & 0xff;
scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f);
@@ -4186,7 +4179,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, scale.f * vaddvq_f32(accd[iy]));
info.store(ix, iy, d1 * vaddvq_f32(accd[iy]));
}
}