mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Compile time option to use bf16 for qunts without MMQ kernels (#261)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -118,6 +118,7 @@ option(GGML_MUSA "ggml: use MUSA"
|
||||
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
|
||||
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
|
||||
option(GGML_CUDA_IQK_FORCE_BF16 "ggml: use bf16 cuBLAS when no MMQ kernel is available" OFF)
|
||||
set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
|
||||
set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
|
||||
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
|
||||
|
||||
@@ -362,6 +362,10 @@ if (GGML_CUDA)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_IQK_FORCE_BF16)
|
||||
add_compile_definitions(GGML_CUDA_IQK_FORCE_BF16)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_FORCE_CUBLAS)
|
||||
add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
|
||||
endif()
|
||||
|
||||
@@ -1264,6 +1264,46 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_IQK_FORCE_BF16
|
||||
if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||||
to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type);
|
||||
GGML_ASSERT(to_bf16_cuda != nullptr);
|
||||
size_t ne = row_diff*ne00;
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id), ne);
|
||||
to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream);
|
||||
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
|
||||
if (src1->type != GGML_TYPE_BF16) {
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_bf16.alloc(ne);
|
||||
to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
|
||||
GGML_ASSERT(to_bf16_cuda != nullptr);
|
||||
to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream);
|
||||
}
|
||||
const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
|
||||
const nv_bfloat16 * src0_ptr = src0_as_bf16.get();
|
||||
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
|
||||
src1_ptr, CUDA_R_16BF, ne10,
|
||||
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
|
||||
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
|
||||
|
||||
@@ -579,9 +579,16 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst
|
||||
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
||||
const float d = scale * ((x[i].scales[ib] & 254) - 127);
|
||||
const int8_t * values = iq4k_values + ((x[i].scales[ib] & 1) << 4);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d * values[q4[j] & 0xf];
|
||||
y[j+16] = d * values[q4[j] >> 4];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = __float2bfloat16(d * values[q4[j] & 0xf]);
|
||||
y[j+16] = __float2bfloat16(d * values[q4[j] >> 4]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d * values[q4[j] & 0xf];
|
||||
y[j+16] = d * values[q4[j] >> 4];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,9 +617,16 @@ static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, ds
|
||||
aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f);
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d * values[aux8[j+0]];
|
||||
y[j+16] = d * values[aux8[j+4]];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = __float2bfloat16(d * values[aux8[j+0]]);
|
||||
y[j+16] = __float2bfloat16(d * values[aux8[j+4]]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d * values[aux8[j+0]];
|
||||
y[j+16] = d * values[aux8[j+4]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -632,9 +646,16 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_
|
||||
const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32);
|
||||
const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1);
|
||||
const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d1 * values1[q4[j] & 0xf];
|
||||
y[j+16] = d2 * values2[q4[j] >> 4];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = __float2bfloat16(d1 * values1[q4[j] & 0xf]);
|
||||
y[j+16] = __float2bfloat16(d2 * values2[q4[j] >> 4]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d1 * values1[q4[j] & 0xf];
|
||||
y[j+16] = d2 * values2[q4[j] >> 4];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -656,12 +677,22 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_
|
||||
const uint8_t * qs = x[i].qs + 32*ib64 + 2*il;
|
||||
const uint8_t * qh = x[i].qh + 2*il;
|
||||
const uint8_t extra = x[i].extra >> 4*(ib64%4);
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4);
|
||||
y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)];
|
||||
y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)];
|
||||
y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)];
|
||||
y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4);
|
||||
y[j+ 0] = __float2bfloat16(dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]);
|
||||
y[j+16] = __float2bfloat16(dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]);
|
||||
y[j+32] = __float2bfloat16(dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]);
|
||||
y[j+48] = __float2bfloat16(dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4);
|
||||
y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)];
|
||||
y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)];
|
||||
y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)];
|
||||
y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,10 +720,17 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_
|
||||
uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4);
|
||||
uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2);
|
||||
uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2);
|
||||
y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0));
|
||||
y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0));
|
||||
y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0));
|
||||
y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0));
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
y[j+ 0] = __float2bfloat16(dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)));
|
||||
y[j+16] = __float2bfloat16(dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)));
|
||||
y[j+32] = __float2bfloat16(dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)));
|
||||
y[j+48] = __float2bfloat16(dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)));
|
||||
} else {
|
||||
y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0));
|
||||
y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0));
|
||||
y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0));
|
||||
y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,11 +751,20 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_
|
||||
const float dl4 = d * (((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 8);
|
||||
const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
|
||||
const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
|
||||
y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)];
|
||||
y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)];
|
||||
y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]);
|
||||
y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]);
|
||||
y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]);
|
||||
y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
|
||||
y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)];
|
||||
y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)];
|
||||
y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -741,11 +788,20 @@ static __global__ void dequantize_block_iq2_ks(const void * __restrict__ vx, dst
|
||||
const float dl3 = d * (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16);
|
||||
const float dl4 = d * (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16);
|
||||
const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
|
||||
y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)];
|
||||
y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)];
|
||||
y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]);
|
||||
y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]);
|
||||
y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]);
|
||||
y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)];
|
||||
y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)];
|
||||
y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)];
|
||||
y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -768,12 +824,22 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_
|
||||
const uint8_t * qs = x[i].qs + 32*ib128 + 2*il;
|
||||
const uint8_t * qh = x[i].qh + 2*il;
|
||||
const int16_t extra = x[i].extra >> (8*ib128 + (il/8));
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h = qh[j] >> (4*(ib128%2));
|
||||
y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)];
|
||||
y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)];
|
||||
y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)];
|
||||
y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)];
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h = qh[j] >> (4*(ib128%2));
|
||||
y[j+ 0] = __float2bfloat16(dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]);
|
||||
y[j+32] = __float2bfloat16(dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]);
|
||||
y[j+64] = __float2bfloat16(dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]);
|
||||
y[j+96] = __float2bfloat16(dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]);
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
const uint8_t h = qh[j] >> (4*(ib128%2));
|
||||
y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)];
|
||||
y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)];
|
||||
y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)];
|
||||
y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1064,6 +1130,22 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
|
||||
return convert_to_bf16_cuda<float>;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_to_bf16_cuda<half>;
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
return dequantize_row_iq2_ks_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
return dequantize_row_iq2_k_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
return dequantize_row_iq3_k_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
return dequantize_row_iq4_kss_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
return dequantize_row_iq4_ks_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
return dequantize_row_iq4_k_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
return dequantize_row_iq5_k_cuda<nv_bfloat16>;
|
||||
case GGML_TYPE_IQ6_K:
|
||||
return dequantize_row_iq6_k_cuda<nv_bfloat16>;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user