Adapting iq2_bn: CUDA dequantize

This commit is contained in:
Iwan Kawrakow
2024-10-23 19:33:05 +03:00
parent 2db9f1e314
commit 0d17e8c3c7

View File

@@ -558,24 +558,28 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
const int64_t ii = blockIdx.x;
const block_iq2_bn * x = (const block_iq2_bn *) vx;
static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size, int64_t nrows) {
int64_t ii = 256*blockIdx.x;
const int64_t tid = threadIdx.x;
int64_t ib64 = tid%4; // 0...3
int64_t il = tid/4; // 0...7
dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
int64_t i = 256/QK_IQ1BN * ii + ib64;
if (i >= nb64) return;
const float m = -1;
dst_t * y = yy + ii + 64*ib64 + 2*il;
int64_t row = ii / n_per_row;
if (row >= nrows) return;
const char * cx = (const char *)vx + row * row_size;
float d = *(const float *)cx;
const block_iq2_bn * x = (const block_iq2_bn *)(cx + sizeof(float));
ii -= row*n_per_row;
int64_t i = ii/QK_IQ1BN + ib64;
const float m = -d;
auto qs = x[i].qs + 2*il;
for (int j = 0; j < 2; ++j) {
y[j+ 0] = ((qs[j] >> 0) & 3) + m;
y[j+16] = ((qs[j] >> 2) & 3) + m;
y[j+32] = ((qs[j] >> 4) & 3) + m;
y[j+48] = ((qs[j] >> 6) & 3) + m;
y[j+ 0] = d * ((qs[j] >> 0) & 3) + m;
y[j+16] = d * ((qs[j] >> 2) & 3) + m;
y[j+32] = d * ((qs[j] >> 4) & 3) + m;
y[j+48] = d * ((qs[j] >> 6) & 3) + m;
}
}
@@ -991,9 +995,9 @@ static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t
template<typename dst_t>
static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb64 = k / QK_IQ1BN;
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
const int nb = (k + 255) / 256;
dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows);
}
template<typename dst_t>