This commit is contained in:
Iwan Kawrakow
2024-11-15 17:01:53 +02:00
parent 81cd220f93
commit 3ee5434601

View File

@@ -340,6 +340,18 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
uint32_t s;
const half * h = (const half *)&s;
val = ka*val + kb;
s = (val & kmask) ^ km32;
return (float)(h[0] +h[1]);
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
@@ -350,24 +362,14 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
uint32_t s;
const half * h = (const half *)&s;
for (int j = 0; j < 8; ++j) {
idx = ka*idx + kb;
s = (idx & kmask) ^ km32;
float val = (float)h[0] + (float)h[1];
y[j] = dl * val;
y[j] = dl * trellis_next(idx);
}
}
@@ -381,11 +383,6 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
const int8_t * scale_values = iq4k_values + 16;
const int64_t tid = threadIdx.x;
@@ -393,16 +390,10 @@ static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
uint32_t idx1 = x[i].ql[2*ib+0] + ((x[i].qh[(2*ib+0)%32] << (8-4*((2*ib+0)/32))) & 0xf00) + 4096;
uint32_t idx2 = x[i].ql[2*ib+1] + ((x[i].qh[(2*ib+1)%32] << (8-4*((2*ib+1)/32))) & 0xf00) + 4096;
//const int8_t * sv = (const int8_t *)x[i].scales;
//const float dl = scale * sv[ib/8] * 31.75f * 1.015f;
const float dl = scale * scale_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.015f;
uint32_t s[2];
const half * h = (const half *)s;
for (int j = 0; j < 4; ++j) {
idx1 = ka*idx1 + kb; s[0] = (idx1 & kmask) ^ km32;
idx2 = ka*idx2 + kb; s[1] = (idx2 & kmask) ^ km32;
y[j+0] = dl * (float)(h[0] + h[1]);
y[j+4] = dl * (float)(h[2] + h[3]);
y[j+0] = dl * trellis_next(idx1);
y[j+4] = dl * trellis_next(idx2);
}
}
@@ -417,10 +408,6 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
const int64_t i = ii - (row*n_per_row)/QK_K;
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
constexpr int kNumGroups = 64;
const int64_t tid = threadIdx.x;
@@ -437,13 +424,9 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset;
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
uint32_t s[2];
const half * h = (const half *)s;
for (int j = 0; j < 4; ++j) {
idx1 = ka*idx1 + kb; s[0] = (idx1 & kmask) ^ km32;
idx2 = ka*idx2 + kb; s[1] = (idx2 & kmask) ^ km32;
y[j+0] = dl * (float)(h[0] + h[1]) + row_av;
y[j+4] = dl * (float)(h[2] + h[3]) + row_av;
y[j+0] = dl * trellis_next(idx1) + row_av;
y[j+4] = dl * trellis_next(idx2) + row_av;
}
}