From 2dbb3d70bf92b976087cddba4250f06e4849c917 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Aug 2024 14:58:00 +0300 Subject: [PATCH] soft_cap_max: WIP - something is wrong with CUDA --- ggml/src/ggml-cuda.cu | 4 +++ ggml/src/ggml-cuda/softmax.cu | 66 ++++++++++++++++++++++++++-------- ggml/src/ggml-cuda/softmax.cuh | 2 ++ ggml/src/ggml.c | 50 +++++++++++++++++++++++--- 4 files changed, 102 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 73ab0b73..056ca4a4 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2286,6 +2286,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_cuda_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_CAP_MAX: + ggml_cuda_op_soft_cap_max(ctx, dst); + break; case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; @@ -2876,6 +2879,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_CAP_MAX: return true; case GGML_OP_ROPE: return ggml_is_contiguous(op->src[0]); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c24abae1..0701d430 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -12,7 +12,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const float * __restrict__ cap_params) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -44,7 +44,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = cap_params ? cap_params[1]*tanhf(cap_params[0]*(x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f))) + : x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -116,7 +117,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst } template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, const float * __restrict__ cap_params, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -134,36 +135,36 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params); } } @@ -197,10 +198,45 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, NULL, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, NULL, stream); + } +} + +void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + const float * src0_d = (const float *)src0->data; + const void * src1_d = src1 ? (const void *)src1->data : nullptr; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + float params[4]; + memcpy(params, dst->op_params, sizeof(params)); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params + 2, stream); + } else { + const float * src1_dd = (const float *)src1_d; + + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params + 2, stream); } } diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 4ef4ff86..49a83dfa 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -3,3 +3,5 @@ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5db2479b..428bb229 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s } } +static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) { + int i = 0; + float max = -INFINITY; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 vs_before = _mm512_set1_ps(2.f*s_before); + __m512 vs_after = _mm512_set1_ps(s_after); + __m512 vmax = _mm512_set1_ps(-INFINITY); + for (; i + 15 < n; i += 16) { + __m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after); + _mm512_storeu_ps(x + i, y); + vmax = _mm512_max_ps(vmax, y); + } + max = _mm512_reduce_max_ps(vmax); +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after)); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after)); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + float32x4_t vs_before = vdupq_n_f32(s_before); + float32x4_t vs_after = vdupq_n_f32(s_after); + for (; i + 3 < n; i += 4) { + vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after)); + } +#endif + for (; i < n; ++i) { + x[i] = s_after*tanhf(x[i]*s_before); + max = MAX(max, x[i]); + } + return max; +} + inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { @@ -6016,7 +6051,7 @@ struct ggml_tensor * ggml_softcap_max( float max_bias, float s_before, float s_after) { - ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false); + return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false); } struct ggml_tensor * ggml_softcap_max_inplace( @@ -6027,7 +6062,7 @@ struct ggml_tensor * ggml_softcap_max_inplace( float max_bias, float s_before, float s_after) { - ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true); + return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true); } @@ -13759,7 +13794,8 @@ static void ggml_compute_forward_softcap_max_f32( } } - ggml_vec_softcap_f32(nc, wp, values[2], values[3]); + //ggml_vec_softcap_f32(nc, wp, values[2], values[3]); + float max = ggml_vec_softcap_max_f32(nc, wp, values[2], values[3]); #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -13768,8 +13804,8 @@ static void ggml_compute_forward_softcap_max_f32( } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + //float max = -INFINITY; + //ggml_vec_max_f32(nc, &max, wp); ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); @@ -18418,6 +18454,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_SOFT_CAP_MAX: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SET: { const size_t nb1 = ((int32_t *) tensor->op_params)[0];