Adding bf16 support to CUDA - matrix multipications

This commit is contained in:
Iwan Kawrakow
2024-09-05 11:21:12 +03:00
parent 76be98fdec
commit 38ae720676
5 changed files with 119 additions and 1 deletions

View File

@@ -1227,7 +1227,38 @@ static void ggml_cuda_op_mul_mat_cublas(
const int compute_capability = ggml_cuda_info().devices[id].cc;
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
if (src1->type != GGML_TYPE_BF16) {
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
GGML_ASSERT(to_bf16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_bf16.alloc(ne);
to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
}
const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
src1_ptr, CUDA_R_16BF, ne10,
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
CUBLAS_COMPUTE_32F,
//CUBLAS_COMPUTE_16BF,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
else if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
@@ -1936,6 +1967,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
}
//printf("%s: %s(%s) x %s(%s), %d %d %d %d %d %d\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type),
// use_dequantize_mul_mat_vec, use_mul_mat_vec_q, use_mul_mat_q, split, use_mul_mat_q, any_gpus_with_slow_fp16);
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -1946,21 +1980,28 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
//printf(" branch 1\n");
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
//printf(" branch 2\n");
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
//printf(" branch 3\n");
// KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
//printf(" branch 4\n");
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
//printf(" branch 5\n");
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
//printf(" branch 6\n");
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
} else {
//printf(" branch 7\n");
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
}
@@ -2734,6 +2775,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
//printf("%s(%s, %s)\n", __func__, op->name, ggml_op_name(op->op));
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
@@ -2756,6 +2798,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
//printf("%s(%s x %s, %s, %s)\n", __func__, a->name, b->name, ggml_type_name(a->type), ggml_type_name(b->type));
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
@@ -2764,6 +2807,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:

View File

@@ -926,6 +926,45 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
y[i] = x[i];
}
template <typename dst_t>
static __global__ void convert_from_bf16(const nv_bfloat16 * __restrict__ x, dst_t * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
y[i] = __bfloat162float(x[i]);
//typedef union { uint32_t u; float f; } aux_t;
//const uint16_t * u16 = (const uint16_t *) x;
//aux_t aux;
//aux.u = u16[i] << 16;
//y[i] = aux.f;
}
static __global__ void convert_to_bf16(const float * __restrict__ x, nv_bfloat16 * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
y[i] = __float2bfloat16(x[i]);
}
static __global__ void convert_to_bf16(const half * __restrict__ x, nv_bfloat16 * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
y[i] = __float2bfloat16((float)x[i]);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -933,6 +972,32 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
template <typename dst_t>
static void convert_from_bf16_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_from_bf16<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>((const nv_bfloat16 *)vx, y, k);
}
//=> to_bf16_cuda_t = void(*)(const void * __restrict__ x, nv_bfloat16 * y, int64_t k, cudaStream_t stream);
template <typename src_t>
static void convert_to_bf16_cuda(const void * __restrict__ vx, nv_bfloat16 * __restrict__ y, const int64_t k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_to_bf16<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>((const src_t *)vx, y, k);
}
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_to_bf16_cuda<float>;
case GGML_TYPE_F16:
return convert_to_bf16_cuda<half>;
default:
return nullptr;
}
}
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
@@ -996,6 +1061,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
case GGML_TYPE_BF16:
return convert_from_bf16_cuda;
default:
return nullptr;
}
@@ -1061,6 +1128,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
case GGML_TYPE_BF16:
return convert_from_bf16_cuda;
default:
return nullptr;
}

View File

@@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);

View File

@@ -4,6 +4,7 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED

View File

@@ -3,6 +3,7 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"