diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index eef83572..a306ee26 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1836,7 +1836,7 @@ static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml } static __global__ void k_compute_batched_ptrs( - const half * src0_as_f16, const half * src1_as_f16, char * dst, + const void * src0_as_f16, const void * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, @@ -1844,86 +1844,155 @@ static __global__ void k_compute_batched_ptrs( size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, int64_t r2, int64_t r3) { - int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; - int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; + const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; if (i13 >= ne13 || i12 >= ne12) { return; } - int64_t i03 = i13 / r3; - int64_t i02 = i12 / r2; + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } -static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +// Type traits for mapping ggml types to CUDA/cuBLAS types +template +struct batched_mul_mat_traits; + +template<> +struct batched_mul_mat_traits { + using cuda_type = float; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_32F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F32; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = nv_bfloat16; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_16BF; + static inline const ggml_type ggml_type_val = GGML_TYPE_BF16; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = half; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + static inline const cudaDataType_t data_type = CUDA_R_16F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F16; + static inline const half alpha = 1.0; + static inline const half beta = 0.0; + static inline const void* get_alpha() { static const half val = alpha; return &val; } + static inline const void* get_beta() { static const half val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } +}; + +template +static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + using traits = batched_mul_mat_traits; + using cuda_t = typename traits::cuda_type; + GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(src0->type == src0_type); + GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. + // As long as dst is contiguous this does not matter though. GGML_TENSOR_BINARY_OP_LOCALS const int64_t ne_dst = ggml_nelements(dst); - cudaStream_t main_stream = ctx.stream(); - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream)); - void * src0_ddq = src0->data; - half * src0_f16 = (half *) src0_ddq; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; + float * dst_ddf = (float *) dst->data; + const size_t ts_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == ts_src1); + int64_t s11 = nb11 / ts_src1; + int64_t s12 = nb12 / ts_src1; + int64_t s13 = nb13 / ts_src1; - // convert src1 to fp16 - ggml_cuda_pool_alloc src1_f16_alloc(ctx.pool()); - if (src1->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + const cuda_t * src0_ptr = nullptr; + const cuda_t * src1_ptr = nullptr; + + ggml_cuda_pool_alloc src0_alloc(ctx.pool()); + ggml_cuda_pool_alloc src1_alloc(ctx.pool()); + + bool is_src0_cont_2 = ggml_is_contiguous_2(src0); + bool is_src1_cont_2 = ggml_is_contiguous_2(src1); + + // Handle src0 + src0_ptr = (const cuda_t *) src0->data; + + // Handle src1 - convert if necessary + if (src1->type == src0_type) { + src1_ptr = (const cuda_t *) src1->data; + } else { + // Convert src1 to target type using traits conversion functions const int64_t ne_src1 = ggml_nelements(src1); - src1_f16_alloc.alloc(ne_src1); - GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ggml_nrows(src1), src1->ne[0], main_stream); + src1_alloc.alloc(ne_src1); + + const auto convert_func = traits::get_nc_converter(src1->type); + GGML_ASSERT(convert_func != nullptr); + convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); + src1_ptr = src1_alloc.get(); + s11 = ne10; + s12 = ne11*s11; + s13 = ne12*s12; + + is_src1_cont_2 = true; } - half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool()); + // Setup destination buffer + ggml_cuda_pool_alloc dst_temp(ctx.pool()); char * dst_t; - - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - cudaDataType_t cu_data_type = CUDA_R_16F; - - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - + cublasComputeType_t cu_compute_type = traits::compute_type; + cudaDataType_t cu_data_type = traits::data_type; + cudaDataType_t cu_data_type_a = traits::data_type; + cudaDataType_t cu_data_type_b = traits::data_type; + const void * alpha = traits::get_alpha(); + const void * beta = traits::get_beta(); const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - - const void * alpha = &alpha_f16; - const void * beta = &beta_f16; + const float beta_f32 = 0.0f; if (dst->op_params[0] == GGML_PREC_DEFAULT) { - dst_t = (char *) dst_f16.alloc(ne_dst); - - nbd2 /= sizeof(float) / sizeof(half); - nbd3 /= sizeof(float) / sizeof(half); + if constexpr (src0_type == GGML_TYPE_F32) { + dst_t = (char *) dst_ddf; // Direct F32 output + } else { + dst_t = (char *) dst_temp.alloc(ne_dst); + nbd2 /= sizeof(float) / sizeof(cuda_t); + nbd3 /= sizeof(float) / sizeof(cuda_t); + } } else { dst_t = (char *) dst_ddf; - cu_compute_type = CUBLAS_COMPUTE_32F; - cu_data_type = CUDA_R_32F; - + cu_data_type = CUDA_R_32F; alpha = &alpha_f32; - beta = &beta_f32; + beta = &beta_f32; } + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -1931,77 +2000,85 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use cublasGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; + if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: + const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; + const int64_t smb = ne12 == 1 ? s13 : s12; - CUBLAS_CHECK( - cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else -#ifdef GGML_USE_MUSA - GGML_ASSERT(false); -#else // !GGML_USE_MUSA - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA - (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB - beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC + alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA + src1_ptr, cu_data_type_b, s11, smb, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { // use cublasGemmBatchedEx - const int ne23 = ne12*ne13; + const int64_t ne23 = ne12*ne13; ggml_cuda_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( - src0_f16, src1_f16, dst_t, + size_t src1_stride_size = sizeof(cuda_t); + + const int threads_x = 16; + const int threads_y = 16; + dim3 block_dims(threads_x, threads_y); + + dim3 grid_dims( + (ne13 + threads_x - 1) / threads_x, + (ne12 + threads_y - 1) / threads_y + ); + k_compute_batched_ptrs<<>>( + src0_ptr, src1_ptr, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, ne23, nb02, nb03, - src1->type == GGML_TYPE_F16 ? nb12 : nb12/2, - src1->type == GGML_TYPE_F16 ? nb13 : nb13/2, + (src1->type == src0_type) ? nb12 : s12*src1_stride_size, + (src1->type == src0_type) ? nb13 : s13*src1_stride_size, nbd2, nbd3, r2, r3); + CUDA_CHECK(cudaGetLastError()); CUBLAS_CHECK( cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10, - beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01, + alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00, + (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11, + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, ne23, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -#endif // GGML_USE_MUSA -#endif - if (dst->op_params[0] == GGML_PREC_DEFAULT) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_ddf, ggml_nrows(dst), dst->ne[0], main_stream); + // Convert output back to F32 if needed + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val); + to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, 1, main_stream); + } +} + +static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32); + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_BF16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_F16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + default: + GGML_ABORT("Unsupported type"); } } diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 689613f5..05e0ba19 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -2122,3 +2122,69 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return nullptr; } } + +// non-contuigous conversions + +template +static __global__ void convert_unary( + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const src_t * x = (const src_t *) vx; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); +} + +template +static void convert_unary_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + convert_unary<<>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; + default: + return nullptr; + } +} + +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_F16: + return convert_unary_cuda; + default: + return nullptr; + } +} + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: + return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; + default: + return nullptr; + } +} + diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 9e131d2b..a519980c 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -22,6 +22,19 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); +template +using to_t_nc_cuda_t = void (*)(const void * x, T * y, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream); + +typedef to_t_nc_cuda_t to_fp32_nc_cuda_t; +typedef to_t_nc_cuda_t to_fp16_nc_cuda_t; +typedef to_t_nc_cuda_t to_bf16_nc_cuda_t; + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type); +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); + template __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) { if constexpr (std::is_same_v) { @@ -30,6 +43,15 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + // bypass compile error on cuda 12.0.1 +#ifdef GGML_USE_HIPBLAS + return __float22bfloat162_rn(x); +#else + return {x.x, x.y}; +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 50e6458d..7a5e3841 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -940,6 +940,5 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT || - src0_type == GGML_TYPE_F16; + src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT; }