mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
bitnet(scale in a separate tensor): CUDA
This commit is contained in:
@@ -432,8 +432,7 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
|
|||||||
int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
|
int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
|
||||||
if (i >= nb64) return;
|
if (i >= nb64) return;
|
||||||
ib = ib%(QK_IQ1BN/32);
|
ib = ib%(QK_IQ1BN/32);
|
||||||
float d = iq1bn_fp8_to_float(x[i].extra & 0xff);
|
const float dl = x[i].extra & (1 << (4*ib + il)) ? -1 : 1;
|
||||||
const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? -d : d;
|
|
||||||
const float ml = -dl;
|
const float ml = -dl;
|
||||||
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
|
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
|
||||||
const uint16_t gp = iq1bn_grid_u16[idx];
|
const uint16_t gp = iq1bn_grid_u16[idx];
|
||||||
@@ -454,14 +453,13 @@ static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst
|
|||||||
dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
|
dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
|
||||||
int64_t i = 256/QK_IQ1BN * ii + ib64;
|
int64_t i = 256/QK_IQ1BN * ii + ib64;
|
||||||
if (i >= nb64) return;
|
if (i >= nb64) return;
|
||||||
const float d = x[i].d;
|
const float m = -1;
|
||||||
const float m = -d;
|
|
||||||
auto qs = x[i].qs + 2*il;
|
auto qs = x[i].qs + 2*il;
|
||||||
for (int j = 0; j < 2; ++j) {
|
for (int j = 0; j < 2; ++j) {
|
||||||
y[j+ 0] = d * ((qs[j] >> 0) & 3) + m;
|
y[j+ 0] = ((qs[j] >> 0) & 3) + m;
|
||||||
y[j+16] = d * ((qs[j] >> 2) & 3) + m;
|
y[j+16] = ((qs[j] >> 2) & 3) + m;
|
||||||
y[j+32] = d * ((qs[j] >> 4) & 3) + m;
|
y[j+32] = ((qs[j] >> 4) & 3) + m;
|
||||||
y[j+48] = d * ((qs[j] >> 6) & 3) + m;
|
y[j+48] = ((qs[j] >> 6) & 3) + m;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1078,8 +1078,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
|||||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||||
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
|
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
|
||||||
|
|
||||||
float d = iq1bn_fp8_to_float(bq1->extra & 0xff);
|
uint8_t extra = bq1->extra >> 4*iqs;
|
||||||
uint8_t extra = bq1->extra >> (8 + 4*iqs);
|
|
||||||
int sumi = 0;
|
int sumi = 0;
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
const int * q8 = (const int *)bq8_1[iqs].qs;
|
const int * q8 = (const int *)bq8_1[iqs].qs;
|
||||||
@@ -1107,7 +1106,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
|||||||
q8 += 8;
|
q8 += 8;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
return d * __low2float(bq8_1[iqs].ds) * sumi;
|
return __low2float(bq8_1[iqs].ds) * sumi;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
@@ -1132,7 +1131,7 @@ static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
|||||||
}
|
}
|
||||||
auto d8l = __half22float2(bq8_1[0].ds);
|
auto d8l = __half22float2(bq8_1[0].ds);
|
||||||
auto d8h = __half22float2(bq8_1[1].ds);
|
auto d8h = __half22float2(bq8_1[1].ds);
|
||||||
return (float)bq2->d * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
|
||||||
#else
|
#else
|
||||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
||||||
auto q8l = bq8_1[0].qs + 8*iqs;
|
auto q8l = bq8_1[0].qs + 8*iqs;
|
||||||
@@ -1146,7 +1145,7 @@ static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
|||||||
}
|
}
|
||||||
auto d8l = __half22float2(bq8_1[0].ds);
|
auto d8l = __half22float2(bq8_1[0].ds);
|
||||||
auto d8h = __half22float2(bq8_1[1].ds);
|
auto d8h = __half22float2(bq8_1[1].ds);
|
||||||
return (float)bq2->d * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user