diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 29fb5075..42012b9a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1227,7 +1227,38 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i; + ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + //CUBLAS_COMPUTE_16BF, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } + else if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1936,6 +1967,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } + //printf("%s: %s(%s) x %s(%s), %d %d %d %d %d %d\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type), + // use_dequantize_mul_mat_vec, use_mul_mat_vec_q, use_mul_mat_q, split, use_mul_mat_q, any_gpus_with_slow_fp16); + // debug helpers //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); @@ -1946,21 +1980,28 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // FP32 precision KQ single-batch for batch size 1 without FlashAttention + //printf(" branch 1\n"); ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + //printf(" branch 2\n"); // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + //printf(" branch 3\n"); // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { + //printf(" branch 4\n"); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { + //printf(" branch 5\n"); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { + //printf(" branch 6\n"); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { + //printf(" branch 7\n"); ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } } @@ -2734,6 +2775,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + //printf("%s(%s, %s)\n", __func__, op->name, ggml_op_name(op->op)); switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -2756,6 +2798,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { + //printf("%s(%s x %s, %s, %s)\n", __func__, a->name, b->name, ggml_type_name(a->type), ggml_type_name(b->type)); return false; } if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { @@ -2764,6 +2807,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 03de64ef..cabe64db 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -926,6 +926,45 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res y[i] = x[i]; } +template +static __global__ void convert_from_bf16(const nv_bfloat16 * __restrict__ x, dst_t * __restrict__ y, const int64_t k) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + y[i] = __bfloat162float(x[i]); + + //typedef union { uint32_t u; float f; } aux_t; + + //const uint16_t * u16 = (const uint16_t *) x; + //aux_t aux; + //aux.u = u16[i] << 16; + + //y[i] = aux.f; +} + +static __global__ void convert_to_bf16(const float * __restrict__ x, nv_bfloat16 * __restrict__ y, const int64_t k) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + y[i] = __float2bfloat16(x[i]); +} + +static __global__ void convert_to_bf16(const half * __restrict__ x, nv_bfloat16 * __restrict__ y, const int64_t k) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + y[i] = __float2bfloat16((float)x[i]); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -933,6 +972,32 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_ convert_unary<<>>(vx, y, k); } +template +static void convert_from_bf16_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + convert_from_bf16<<>>((const nv_bfloat16 *)vx, y, k); +} + +//=> to_bf16_cuda_t = void(*)(const void * __restrict__ x, nv_bfloat16 * y, int64_t k, cudaStream_t stream); + + +template +static void convert_to_bf16_cuda(const void * __restrict__ vx, nv_bfloat16 * __restrict__ y, const int64_t k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + convert_to_bf16<<>>((const src_t *)vx, y, k); +} + +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_to_bf16_cuda; + case GGML_TYPE_F16: + return convert_to_bf16_cuda; + default: + return nullptr; + } +} + to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -996,6 +1061,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_from_bf16_cuda; default: return nullptr; } @@ -1061,6 +1128,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_from_bf16_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 1fb53900..0efcecde 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in typedef to_t_cuda_t to_fp32_cuda_t; typedef to_t_cuda_t to_fp16_cuda_t; +typedef to_t_cuda_t to_bf16_cuda_t; to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type); to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); + +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index db9f6a16..840809a1 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -4,6 +4,7 @@ #include #include #include +#include #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d0c37725..d1d16431 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() #include "rocblas/rocblas.h"