iq1bn(no lookup): CUDA

Not good. We only get ~160 t/s.
This commit is contained in:
Iwan Kawrakow
2024-07-15 19:56:51 +03:00
parent e4dc3babb5
commit ef39ca6a2c
2 changed files with 63 additions and 44 deletions

View File

@@ -425,27 +425,19 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
const int64_t ii = blockIdx.x;
const block_iq1_bn * x = (const block_iq1_bn *) vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
int64_t ib = tid%8; // 0...7
dst_t * y = yy + ii*QK_K + 32*ib + 8*il;
int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
const int tid = threadIdx.x;
const int il = tid/4; // 0...7
const int ib = tid%4; // 0...3
dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
int64_t i = QK_K/QK_IQ1BN * ii + ib;
if (i >= nb64) return;
ib = ib%(QK_IQ1BN/32);
uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
uint16_t val = x[i].extra & (1 << (4*ib + il)) ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx];
uint32_t aux32[2];
const int8_t * aux8 = (const int8_t *)aux32;
aux32[0] = val | (val << 14);
//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
// aux32[1] = __vsub4((aux32[0] >> 4) & 0x03030303, 0x01010101);
// aux32[0] = __vsub4(aux32[0] & 0x03030303, 0x01010101);
// for (int j = 0; j < 8; ++j) y[j] = aux8[j];
//#else
aux32[1] = (aux32[0] >> 4) & 0x03030303;
aux32[0] &= 0x03030303;
for (int j = 0; j < 8; ++j) y[j] = aux8[j] - 1;
//#endif
uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
for (int j = 0; j < 8; ++j) {
uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
y[j] = v - 1;
}
}
template<typename dst_t>

View File

@@ -1078,40 +1078,67 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
// iqs is 0 or 1
uint8_t extra = bq1->extra >> 4*iqs;
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[iqs].qs;
int val32, v1, v2;
//int v[2];
//int8_t * a = (int8_t *)v;
//for (int l = 0; l < 2; ++l) {
// int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
// int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
// for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
// sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi));
// for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
// sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi));
// extra >>= 2;
//}
int v[4];
int8_t * a = (int8_t *)v;
for (int l = 0; l < 2; ++l) {
uint16_t idx1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*iqs + l] << 8) & 0x0f00);
uint16_t idx2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*iqs + l] << 4) & 0x0f00);
uint16_t val1 = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx1] : iq1bn_grid_u16[idx1];
uint16_t val2 = extra & 2 ? 0xaaaa - iq1bn_grid_u16[idx2] : iq1bn_grid_u16[idx2];
val32 = val1 | (val1 << 14);
v1 = __vsub4(val32 & 0x03030303, 0x01010101);
v2 = __vsub4((val32 >> 4) & 0x03030303, 0x01010101);
sumi = __dp4a(v1, q8[4*l+0], __dp4a(v2, q8[4*l+1], sumi));
val32 = val2 | (val2 << 14);
v1 = __vsub4(val32 & 0x03030303, 0x01010101);
v2 = __vsub4((val32 >> 4) & 0x03030303, 0x01010101);
sumi = __dp4a(v1, q8[4*l+2], __dp4a(v2, q8[4*l+3], sumi));
int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
for (int k = 0; k < 8; ++k) {
a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
}
sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi))));
extra >>= 2;
}
//int v[8];
//int8_t * a = (int8_t *)v;
//int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
//int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
//int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096);
//int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096);
//for (int k = 0; k < 8; ++k) {
// a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1;
// a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1;
// a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1;
// a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1;
//}
//sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi))));
//sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi))));
#else
uint32_t aux32[2];
const int8_t * aux8 = (const int8_t *)aux32;
const int8_t * q8 = bq8_1[iqs].qs;
for (int l = 0; l < 4; ++l) {
uint16_t idx = bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00);
uint16_t val = extra & 1 ? 0xaaaa - iq1bn_grid_u16[idx] : iq1bn_grid_u16[idx];
aux32[0] = val | (val << 14);
aux32[1] = (aux32[0] >> 4) & 0x03030303;
aux32[0] &= 0x03030303;
for (int j = 0; j < 8; ++j) {
sumi += q8[j] * (aux8[j] - 1);
for (int l = 0; l < 2; ++l) {
int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096);
int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096);
for (int k = 0; k < 4; ++k) {
int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13;
int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13;
int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13;
int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13;
sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12];
}
q8 += 8;
q8 += 16;
extra >>= 2;
}
#endif
return __low2float(bq8_1[iqs].ds) * sumi;