Fix iq2_kt that got broken along the way

This commit is contained in:
Iwan Kawrakow
2025-06-07 18:51:02 +03:00
parent 6ba96c8b33
commit 36fba1fff2

View File

@@ -340,14 +340,25 @@ inline __device__ int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
float __device__ __forceinline__ trellis_next(uint32_t& val) {
int __device__ __forceinline__ trellis_next_int(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
val = ka*val + kb;
//return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, 0x82828282);
return ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126);
}
float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
uint32_t s;
const half * h = (const half *)&s;
val = ka*val + kb;
s = (val & kmask) ^ km32;
return (float)(h[0]+h[1]);
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
@@ -363,7 +374,7 @@ static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 1.05f;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
for (int j = 0; j < 8; ++j) {
y[j] = dl * trellis_next(idx);
}
@@ -398,7 +409,6 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
float scale = dptr[0] * 1.00f;
float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
const int64_t i = ii - (row*n_per_row)/QK_K;
@@ -419,8 +429,8 @@ static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
y[j+0] = dl * trellis_next(idx1) + row_av;
y[j+4] = dl * trellis_next(idx2) + row_av;
y[j+0] = dl * trellis_next_int(idx1);
y[j+4] = dl * trellis_next_int(idx2);
}
}