mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq1_bn: improve CUDA TG
On RTX-3080 TG-128(Bitnet-1.58b-3B) goes from 318 t/s to 340 t/s. I see I have on the front page 301 t/s, so pretty nice improvement since then.
This commit is contained in:
@@ -633,31 +633,37 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
||||
float scale = d16;
|
||||
const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx;
|
||||
|
||||
static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
|
||||
|
||||
// iqs is 0 or 1
|
||||
|
||||
int sumi = 0;
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
uint16_t mult[2];
|
||||
mult[1] = iqs == 0 ? 27 : 3;
|
||||
mult[0] = mult[1] + (mult[1] << 1);
|
||||
const int * q8 = (const int *)bq8_1[iqs].qs;
|
||||
int val[4];
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
int8_t * a = (int8_t *)val;
|
||||
const int i16 = 2*iqs + l;
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
uint8_t q = bq1->ql[3*i16+k];
|
||||
for (int j = 0; j < 5; ++j) {
|
||||
uint8_t v = k_mult[j]*q;
|
||||
int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
|
||||
*a++ = vs-1;
|
||||
uint16_t q = bq1->ql[3*i16+k];
|
||||
for (int j = 4; j >= 0; --j) {
|
||||
uint16_t v = q & 0xff;
|
||||
v += v << 1;
|
||||
a[j] = v >> 8;
|
||||
q += q << 1;
|
||||
}
|
||||
a += 5;
|
||||
}
|
||||
uint8_t v = k_mult[i16]*bq1->extra;
|
||||
int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
|
||||
*a++ = vs-1;
|
||||
uint16_t v = (mult[l]*bq1->extra) & 0xff;
|
||||
v += v << 1;
|
||||
*a = v >> 8;
|
||||
sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
|
||||
}
|
||||
float2 d8 = __half22float2(bq8_1[iqs].ds);
|
||||
return scale * (d8.x * sumi - d8.y);
|
||||
#else
|
||||
static const uint16_t k_mult[5] = {81, 27, 9, 3, 1};
|
||||
const int8_t * q8 = bq8_1[iqs].qs;
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
const int i16 = 2*iqs + l;
|
||||
@@ -675,8 +681,8 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
||||
sumi += q8[0]*(vs - 1);
|
||||
q8++;
|
||||
}
|
||||
#endif
|
||||
return scale * __low2float(bq8_1[iqs].ds) * sumi;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
||||
|
||||
Reference in New Issue
Block a user