mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-05 19:40:19 +00:00
Fix #261
This commit is contained in:
@@ -1267,40 +1267,41 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
#ifdef GGML_CUDA_IQK_FORCE_BF16
|
||||
if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||||
to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type);
|
||||
GGML_ASSERT(to_bf16_cuda != nullptr);
|
||||
size_t ne = row_diff*ne00;
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id), ne);
|
||||
to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream);
|
||||
if (to_bf16_cuda) {
|
||||
size_t ne = row_diff*ne00;
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src0_as_bf16(ctx.pool(id), ne);
|
||||
to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream);
|
||||
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
|
||||
if (src1->type != GGML_TYPE_BF16) {
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_bf16.alloc(ne);
|
||||
to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
|
||||
GGML_ASSERT(to_bf16_cuda != nullptr);
|
||||
to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream);
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
|
||||
if (src1->type != GGML_TYPE_BF16) {
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_bf16.alloc(ne);
|
||||
to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
|
||||
GGML_ASSERT(to_bf16_cuda != nullptr);
|
||||
to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream);
|
||||
}
|
||||
const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
|
||||
const nv_bfloat16 * src0_ptr = src0_as_bf16.get();
|
||||
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
|
||||
src1_ptr, CUDA_R_16BF, ne10,
|
||||
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
|
||||
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream);
|
||||
return;
|
||||
}
|
||||
const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
|
||||
const nv_bfloat16 * src0_ptr = src0_as_bf16.get();
|
||||
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
|
||||
src1_ptr, CUDA_R_16BF, ne10,
|
||||
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
|
||||
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user