From 8af4111f97724123d29f991e2b749fc5499e659f Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 29 Oct 2024 15:08:32 +0200
Subject: [PATCH] multi_add: CUDA

---
 ggml/src/ggml-cuda.cu        | 12 ++++++++++
 ggml/src/ggml-cuda/unary.cu  | 50 ++++++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/unary.cuh |  3 +++
 3 files changed, 65 insertions(+)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 6759e202..e38e9568 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_MULTI_ADD:
+            ggml_cuda_op_multi_add(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }
+        if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
 
         if (node->op == GGML_OP_CPY) {
             // store the copy op parameter which changes with each token.
@@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_MULTI_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 7bc43d0f..72043c2e 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -52,6 +52,21 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
     dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    int64_t k = ne0*ne1;
+    if (i >= k) {
+        return;
+    }
+    int i1 = i / ne0;
+    int i0 = i % ne0;
+    float * result = (float *)(dst + i1*nb1);
+    const float * s = (const float *)(src0 + i1*nb01) + i0;
+    float sum = 0;
+    for (int j = 0; j < nused; ++j) sum += s[j*ne0];
+    result[i0] = sum;
+}
+
 static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -218,6 +233,41 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
+    int64_t k = ne0 * ne1;
+    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
+}
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    int nused = 0;
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        ggml_tensor * src = dst->src[i];
+        if (src) {
+            GGML_ASSERT(src->type == GGML_TYPE_F32);
+            GGML_ASSERT(ggml_are_same_shape(src, dst));
+            GGML_ASSERT(src->ne[2] == 1 && src->ne[3] == 1);
+            GGML_ASSERT(src->nb[0] == sizeof(float));
+            ++nused;
+        } else {
+            break;
+        }
+    }
+    GGML_ASSERT(nused >= 2);
+    const char * src0 = (const char *)dst->src[0]->data;
+    const int64_t nb01 = dst->src[0]->ne[0]*sizeof(float);
+    for (int i = 1; i < nused; ++i) {
+        GGML_ASSERT(dst->src[i]->nb[1] == dst->src[0]->nb[1]);
+        const char * src = (const char *)dst->src[i]->data;
+        GGML_ASSERT(src == src0 + i*nb01);
+    }
+    cudaStream_t stream = ctx.stream();
+    multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index d2d478b4..0235a319 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -9,6 +9,7 @@
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_MULTI_ADD_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
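
Reviewer addendum (not part of the commit): ggml_cuda_op_multi_add requires the nused summands to be consecutive ne0-column slabs of a single parent buffer (it asserts src == src0 + i*nb01 with nb01 = ne0*sizeof(float)), so the kernel can fetch the j-th summand of an output element at a fixed offset of j*ne0 floats within the parent row. The standalone sketch below exercises exactly that layout against a host reference; the sizes, the file name, and the condensed mirror of the kernel are illustrative assumptions, not part of the patch.

// multi_add_check.cu (hypothetical file name)
#include <cstdio>
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

#define CUDA_MULTI_ADD_BLOCK_SIZE 256

// Condensed mirror of the kernel added by this patch.
static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01,
                                     const char * src0, char * dst) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne0*ne1) return;
    const int64_t i1 = i / ne0;                       // output row
    const int64_t i0 = i % ne0;                       // column within the row
    float * result = (float *)(dst + i1*nb1);
    const float * s = (const float *)(src0 + i1*nb01) + i0;
    float sum = 0.0f;
    for (int j = 0; j < nused; ++j) sum += s[j*ne0];  // summands sit ne0 floats apart
    result[i0] = sum;
}

int main() {
    const int     nused = 4;                          // assumed number of summands
    const int64_t ne0 = 128, ne1 = 3;                 // assumed row size / row count
    const int64_t nb01 = nused*ne0*sizeof(float);     // parent row stride (the kernel's nb01 = src0->nb[1]): nused slabs of ne0 floats
    const int64_t nb1  = ne0*sizeof(float);           // contiguous destination rows

    std::vector<float> h_src(ne1*nused*ne0), h_dst(ne1*ne0), h_ref(ne1*ne0, 0.0f);
    for (size_t i = 0; i < h_src.size(); ++i) h_src[i] = (float)(i % 7) - 3.0f;
    // Host reference: accumulate the summands in the same order as the kernel,
    // so the result is bitwise identical.
    for (int64_t i1 = 0; i1 < ne1; ++i1)
        for (int j = 0; j < nused; ++j)
            for (int64_t i0 = 0; i0 < ne0; ++i0)
                h_ref[i1*ne0 + i0] += h_src[i1*nused*ne0 + j*ne0 + i0];

    char * d_src; char * d_dst;
    cudaMalloc(&d_src, h_src.size()*sizeof(float));
    cudaMalloc(&d_dst, h_dst.size()*sizeof(float));
    cudaMemcpy(d_src, h_src.data(), h_src.size()*sizeof(float), cudaMemcpyHostToDevice);

    // Same grid computation as multi_add_f32_cuda in the patch.
    const int64_t k = ne0*ne1;
    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, 0>>>(nused, ne0, ne1, nb1, nb01, d_src, d_dst);
    cudaMemcpy(h_dst.data(), d_dst, h_dst.size()*sizeof(float), cudaMemcpyDeviceToHost);

    int errs = 0;
    for (int64_t i = 0; i < ne1*ne0; ++i) errs += h_dst[i] != h_ref[i];
    printf("multi_add_f32: %s (%d mismatches)\n", errs ? "FAIL" : "OK", errs);
    cudaFree(d_src); cudaFree(d_dst);
    return errs ? 1 : 0;
}

Built with, e.g., nvcc -o multi_add_check multi_add_check.cu, it should print OK; a FAIL would mean the stride bookkeeping (nb1/nb01) and the j*ne0 summand offset no longer agree. On the ggml side no new API is needed to exercise the op: GGML_OP_MULTI_ADD presumably already exists on the CPU path, and this patch only adds the CUDA dispatch for it.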