Fix non-contiguous batched cuBLAS
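The diff below makes the batched cuBLAS mat-mul path handle F32 and BF16 src0 in addition to F16: it introduces a batched_mul_mat_traits<> struct carrying the cuBLAS compute/data types and alpha/beta per type, adds non-contiguous src1 converters (ggml_get_to_fp32/fp16/bf16_nc_cuda), and launches k_compute_batched_ptrs over a ceil-divided grid of 16x16 blocks instead of a single block of ne13 x ne12 threads. To illustrate that launch change in isolation, here is a minimal, self-contained CUDA sketch; the kernel name, sizes, and buffers are made up for the example and are not part of the commit (a single block is limited to 1024 threads, which is a CUDA property rather than anything specific to this code).

// Illustrative only: standalone sketch of the ceil-division grid launch used for
// k_compute_batched_ptrs in the diff below. Kernel and variable names are hypothetical.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill_batch_ids(int *ids, int ne12, int ne13) {
    // one thread per (i12, i13) batch index, same mapping as k_compute_batched_ptrs
    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
    if (i13 >= ne13 || i12 >= ne12) {
        return;
    }
    ids[i13 * ne12 + i12] = i13 * ne12 + i12;
}

int main() {
    // 40*40 = 1600 (i12, i13) combinations would not fit in one thread block,
    // hence 16x16 blocks with a ceil-divided grid
    const int ne12 = 40, ne13 = 40;
    int *ids;
    cudaMallocManaged(&ids, ne12 * ne13 * sizeof(int));

    const int threads_x = 16, threads_y = 16;
    dim3 block_dims(threads_x, threads_y);
    dim3 grid_dims((ne13 + threads_x - 1) / threads_x,
                   (ne12 + threads_y - 1) / threads_y);
    fill_batch_ids<<<grid_dims, block_dims>>>(ids, ne12, ne13);
    cudaDeviceSynchronize();

    printf("last id = %d\n", ids[ne12 * ne13 - 1]); // expect 1599
    cudaFree(ids);
    return 0;
}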
@@ -1836,7 +1836,7 @@ static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml
}

static __global__ void k_compute_batched_ptrs(
        const half * src0_as_f16, const half * src1_as_f16, char * dst,
        const void * src0_as_f16, const void * src1_as_f16, char * dst,
        const void ** ptrs_src, void ** ptrs_dst,
        int64_t ne12, int64_t ne13,
        int64_t ne23,
@@ -1844,86 +1844,155 @@ static __global__ void k_compute_batched_ptrs(
        size_t nb12, size_t nb13,
        size_t nbd2, size_t nbd3,
        int64_t r2, int64_t r3) {
    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

    if (i13 >= ne13 || i12 >= ne12) {
        return;
    }

    int64_t i03 = i13 / r3;
    int64_t i02 = i12 / r2;
    const int64_t i03 = i13 / r3;
    const int64_t i02 = i12 / r2;

    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
    ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// Type traits for mapping ggml types to CUDA/cuBLAS types
template<ggml_type T>
struct batched_mul_mat_traits;

template<>
struct batched_mul_mat_traits<GGML_TYPE_F32> {
    using cuda_type = float;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_32F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
};

template<>
struct batched_mul_mat_traits<GGML_TYPE_BF16> {
    using cuda_type = nv_bfloat16;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_16BF;
    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
};

template<>
struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_type = half;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
    static inline const cudaDataType_t data_type = CUDA_R_16F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
    static inline const half alpha = 1.0;
    static inline const half beta = 0.0;
    static inline const void* get_alpha() { static const half val = alpha; return &val; }
    static inline const void* get_beta() { static const half val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
};

template<ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    using traits = batched_mul_mat_traits<src0_type>;
    using cuda_t = typename traits::cuda_type;

    GGML_ASSERT(!ggml_is_transposed(src0));
    GGML_ASSERT(!ggml_is_transposed(src1));
    GGML_ASSERT(src0->type == src0_type);
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
    // As long as dst is contiguous this does not matter though.

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t ne_dst = ggml_nelements(dst);

    cudaStream_t main_stream = ctx.stream();

    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

    void * src0_ddq = src0->data;
    half * src0_f16 = (half *) src0_ddq;
    float * src1_ddf = (float *) src1->data;
    float * dst_ddf = (float *) dst->data;
    const size_t ts_src1 = ggml_type_size(src1->type);
    GGML_ASSERT(nb10 == ts_src1);
    int64_t s11 = nb11 / ts_src1;
    int64_t s12 = nb12 / ts_src1;
    int64_t s13 = nb13 / ts_src1;

    // convert src1 to fp16
    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
    if (src1->type != GGML_TYPE_F16) {
        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
    const cuda_t * src0_ptr = nullptr;
    const cuda_t * src1_ptr = nullptr;

    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());

    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);

    // Handle src0
    src0_ptr = (const cuda_t *) src0->data;

    // Handle src1 - convert if necessary
    if (src1->type == src0_type) {
        src1_ptr = (const cuda_t *) src1->data;
    } else {
        // Convert src1 to target type using traits conversion functions
        const int64_t ne_src1 = ggml_nelements(src1);
        src1_f16_alloc.alloc(ne_src1);
        GGML_ASSERT(to_fp16_cuda != nullptr);
        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ggml_nrows(src1), src1->ne[0], main_stream);
        src1_alloc.alloc(ne_src1);

        const auto convert_func = traits::get_nc_converter(src1->type);
        GGML_ASSERT(convert_func != nullptr);
        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
        src1_ptr = src1_alloc.get();
        s11 = ne10;
        s12 = ne11*s11;
        s13 = ne12*s12;

        is_src1_cont_2 = true;
    }
    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();

    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
    // Setup destination buffer
    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
    char * dst_t;

    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
    cudaDataType_t cu_data_type = CUDA_R_16F;

    // dst strides
    size_t nbd2 = dst->nb[2];
    size_t nbd3 = dst->nb[3];

    const half alpha_f16 = 1.0f;
    const half beta_f16 = 0.0f;

    cublasComputeType_t cu_compute_type = traits::compute_type;
    cudaDataType_t cu_data_type = traits::data_type;
    cudaDataType_t cu_data_type_a = traits::data_type;
    cudaDataType_t cu_data_type_b = traits::data_type;
    const void * alpha = traits::get_alpha();
    const void * beta = traits::get_beta();
    const float alpha_f32 = 1.0f;
    const float beta_f32 = 0.0f;

    const void * alpha = &alpha_f16;
    const void * beta = &beta_f16;

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
        dst_t = (char *) dst_f16.alloc(ne_dst);

        nbd2 /= sizeof(float) / sizeof(half);
        nbd3 /= sizeof(float) / sizeof(half);
        if constexpr (src0_type == GGML_TYPE_F32) {
            dst_t = (char *) dst_ddf; // Direct F32 output
        } else {
            dst_t = (char *) dst_temp.alloc(ne_dst);
            nbd2 /= sizeof(float) / sizeof(cuda_t);
            nbd3 /= sizeof(float) / sizeof(cuda_t);
        }
    } else {
        dst_t = (char *) dst_ddf;

        cu_compute_type = CUBLAS_COMPUTE_32F;
        cu_data_type = CUDA_R_32F;

        alpha = &alpha_f32;
        beta = &beta_f32;
    }

    int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;

    GGML_ASSERT(ne12 % ne02 == 0);
    GGML_ASSERT(ne13 % ne03 == 0);

@@ -1931,77 +2000,85 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

#if 0
    // use cublasGemmEx
    {
        for (int i13 = 0; i13 < ne13; ++i13) {
            for (int i12 = 0; i12 < ne12; ++i12) {
                int i03 = i13 / r3;
                int i02 = i12 / r2;
    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
        const int64_t smb = ne12 == 1 ? s13 : s12;

                CUBLAS_CHECK(
                    cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                        ne01, ne11, ne10,
                        alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
                               (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
                        beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
                        cu_compute_type,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
            }
        }
    }
#else
#ifdef GGML_USE_MUSA
    GGML_ASSERT(false);
#else // !GGML_USE_MUSA
    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
        // use cublasGemmStridedBatchedEx
        CUBLAS_CHECK(
            cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
                       (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
                beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
                       src1_ptr, cu_data_type_b, s11, smb, // strideB
                beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
                ne12*ne13,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    } else {
        // use cublasGemmBatchedEx
        const int ne23 = ne12*ne13;
        const int64_t ne23 = ne12*ne13;

        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
        ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);

        dim3 block_dims(ne13, ne12);
        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
            src0_f16, src1_f16, dst_t,
        size_t src1_stride_size = sizeof(cuda_t);

        const int threads_x = 16;
        const int threads_y = 16;
        dim3 block_dims(threads_x, threads_y);

        dim3 grid_dims(
            (ne13 + threads_x - 1) / threads_x,
            (ne12 + threads_y - 1) / threads_y
        );
        k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
            src0_ptr, src1_ptr, dst_t,
            ptrs_src.get(), ptrs_dst.get(),
            ne12, ne13,
            ne23,
            nb02, nb03,
            src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
            src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
            (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
            (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
            nbd2, nbd3,
            r2, r3);

        CUDA_CHECK(cudaGetLastError());

        CUBLAS_CHECK(
            cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
                beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                ne23,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
#endif // GGML_USE_MUSA
#endif

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
        to_fp32_cuda(dst_f16.get(), dst_ddf, ggml_nrows(dst), dst->ne[0], main_stream);
    // Convert output back to F32 if needed
    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, 1, main_stream);
    }
}

static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);

    switch (src0->type) {
        case GGML_TYPE_F32:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
            break;
        case GGML_TYPE_BF16:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
            break;
        case GGML_TYPE_F16:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
            break;
        default:
            GGML_ABORT("Unsupported type");
    }
}
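The wrapper above turns the runtime src0->type into a compile-time template parameter, so each instantiation of ggml_cuda_mul_mat_batched_cublas_impl can read its cuBLAS constants from batched_mul_mat_traits. Below is a simplified, hypothetical sketch of that switch-to-template dispatch pattern; the enum, trait contents, and printf body are placeholders, not the real GEMM call.

// Simplified, hypothetical sketch of the traits + runtime-switch dispatch pattern.
#include <cstdio>

enum class mat_type { f32, f16, bf16 };

template <mat_type T> struct mat_traits;                  // per-type constants
template <> struct mat_traits<mat_type::f32>  { static constexpr const char * name = "CUDA_R_32F";  };
template <> struct mat_traits<mat_type::f16>  { static constexpr const char * name = "CUDA_R_16F";  };
template <> struct mat_traits<mat_type::bf16> { static constexpr const char * name = "CUDA_R_16BF"; };

template <mat_type T>
static void gemm_batched_impl() {
    // the real code would pick compute type, alpha/beta and the
    // non-contiguous converter from the traits struct here
    printf("dispatching batched GEMM with data type %s\n", mat_traits<T>::name);
}

static void gemm_batched(mat_type t) {
    switch (t) {                                          // runtime type -> compile-time template
        case mat_type::f32:  gemm_batched_impl<mat_type::f32>();  break;
        case mat_type::f16:  gemm_batched_impl<mat_type::f16>();  break;
        case mat_type::bf16: gemm_batched_impl<mat_type::bf16>(); break;
    }
}

int main() {
    gemm_batched(mat_type::bf16);
    return 0;
}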

@@ -2122,3 +2122,69 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return nullptr;
    }
}

// non-contiguous conversions

template <typename src_t, typename dst_t>
static __global__ void convert_unary(
        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t s01, const int64_t s02, const int64_t s03) {
    const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i00 >= ne00) {
        return;
    }

    const int64_t i01 = blockIdx.y;
    const int64_t i02 = blockIdx.z % ne02;
    const int64_t i03 = blockIdx.z / ne02;

    const src_t * x = (const src_t *) vx;

    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
        (vx, y, ne00, ne01, ne02, s01, s02, s03);
}

to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16>;
        default:
            return nullptr;
    }
}

to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float, nv_bfloat16>;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, nv_bfloat16>;
        default:
            return nullptr;
    }
}

to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, float>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16, float>;
        default:
            return nullptr;
    }
}

@@ -22,6 +22,19 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);

to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);

template<typename T>
using to_t_nc_cuda_t = void (*)(const void * x, T * y,
        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
        int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);

typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;

to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {

@@ -30,6 +43,15 @@ template<typename dst_t, typename src_t>
        return __float2bfloat16(float(x));
    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
        return __bfloat162float(x);
    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
        return __float22half2_rn(x);
    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
        // bypass compile error on cuda 12.0.1
#ifdef GGML_USE_HIPBLAS
        return __float22bfloat162_rn(x);
#else
        return {x.x, x.y};
#endif // GGML_USE_HIP
    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
        return int32_t(x);
    } else {

@@ -940,6 +940,5 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
        src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT ||
        src0_type == GGML_TYPE_F16;
        src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT;
}
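The convert_unary kernel added above is what makes the non-contiguous case work: each thread reads the source through the element strides s01/s02/s03 and writes a densely packed destination. As a rough, self-contained illustration of that indexing, specialized to an assumed fp32-to-fp16 copy with a padded row stride, here is a hypothetical sketch; shapes and buffer names are made up for the example.

// Illustrative sketch of the stride arithmetic in convert_unary: gather a
// non-contiguous fp32 view (element strides s01/s02/s03) into a packed fp16 buffer.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void pack_f32_to_f16(const float *x, __half *y,
                                int64_t ne00, int64_t ne01, int64_t ne02,
                                int64_t s01, int64_t s02, int64_t s03) {
    const int64_t i00 = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
    if (i00 >= ne00) {
        return;
    }
    const int64_t i01 = blockIdx.y;
    const int64_t i02 = blockIdx.z % ne02;
    const int64_t i03 = blockIdx.z / ne02;

    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;        // strided source
    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; // packed destination
    y[iy] = __float2half(x[ix]);
}

int main() {
    // an 8 x 4 x 2 x 1 tensor stored with padded rows: 10 floats per row instead of 8
    const int64_t ne00 = 8, ne01 = 4, ne02 = 2, ne03 = 1;
    const int64_t s01 = 10, s02 = s01*ne01, s03 = s02*ne02;

    float *x; __half *y;
    cudaMallocManaged(&x, s03*ne03*sizeof(float));
    cudaMallocManaged(&y, ne00*ne01*ne02*ne03*sizeof(__half));
    for (int64_t i = 0; i < s03*ne03; ++i) x[i] = (float)i;

    // grid: x covers i00, y covers i01, z covers i02*i03, as in convert_unary_cuda
    const dim3 num_blocks((ne00 + 255)/256, ne01, ne02*ne03);
    pack_f32_to_f16<<<num_blocks, 256>>>(x, y, ne00, ne01, ne02, s01, s02, s03);
    cudaDeviceSynchronize();

    printf("y[last] = %f\n", __half2float(y[ne00*ne01*ne02*ne03 - 1]));
    cudaFree(x); cudaFree(y);
    return 0;
}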