soft_cap_max: WIP - something is wrong with CUDA

This commit is contained in:
Iwan Kawrakow
2024-08-21 14:58:00 +03:00
parent 6e5d728040
commit 2dbb3d70bf
4 changed files with 102 additions and 20 deletions

View File

@@ -2286,6 +2286,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOFT_MAX:
ggml_cuda_op_soft_max(ctx, dst);
break;
case GGML_OP_SOFT_CAP_MAX:
ggml_cuda_op_soft_cap_max(ctx, dst);
break;
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
@@ -2876,6 +2879,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_CAP_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);

View File

@@ -12,7 +12,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const float * __restrict__ cap_params) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -44,7 +44,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = cap_params ? cap_params[1]*tanhf(cap_params[0]*(x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f)))
: x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
@@ -116,7 +117,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, const float * __restrict__ cap_params, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -134,36 +135,36 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params);
}
}
@@ -197,10 +198,45 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, NULL, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, NULL, stream);
}
}
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
float params[4];
memcpy(params, dst->op_params, sizeof(params));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params + 2, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params + 2, stream);
}
}

View File

@@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -2889,6 +2889,41 @@ static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s
}
}
static float ggml_vec_softcap_max_f32(const int n, float * x, float s_before, float s_after) {
int i = 0;
float max = -INFINITY;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
__m512 vs_before = _mm512_set1_ps(2.f*s_before);
__m512 vs_after = _mm512_set1_ps(s_after);
__m512 vmax = _mm512_set1_ps(-INFINITY);
for (; i + 15 < n; i += 16) {
__m512 y = ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after);
_mm512_storeu_ps(x + i, y);
vmax = _mm512_max_ps(vmax, y);
}
max = _mm512_reduce_max_ps(vmax);
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
float32x4_t vs_before = vdupq_n_f32(s_before);
float32x4_t vs_after = vdupq_n_f32(s_after);
for (; i + 3 < n; i += 4) {
vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), vs_before, vs_after));
}
#endif
for (; i < n; ++i) {
x[i] = s_after*tanhf(x[i]*s_before);
max = MAX(max, x[i]);
}
return max;
}
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
@@ -6016,7 +6051,7 @@ struct ggml_tensor * ggml_softcap_max(
float max_bias,
float s_before,
float s_after) {
ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false);
return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, false);
}
struct ggml_tensor * ggml_softcap_max_inplace(
@@ -6027,7 +6062,7 @@ struct ggml_tensor * ggml_softcap_max_inplace(
float max_bias,
float s_before,
float s_after) {
ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true);
return ggml_softcap_max_impl(ctx, a, mask, scale, max_bias, s_before, s_after, true);
}
@@ -13759,7 +13794,8 @@ static void ggml_compute_forward_softcap_max_f32(
}
}
ggml_vec_softcap_f32(nc, wp, values[2], values[3]);
//ggml_vec_softcap_f32(nc, wp, values[2], values[3]);
float max = ggml_vec_softcap_max_f32(nc, wp, values[2], values[3]);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
@@ -13768,8 +13804,8 @@ static void ggml_compute_forward_softcap_max_f32(
}
#endif
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, wp);
//float max = -INFINITY;
//ggml_vec_max_f32(nc, &max, wp);
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
assert(sum > 0.0);
@@ -18418,6 +18454,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_SOFT_CAP_MAX:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_SET:
{
const size_t nb1 = ((int32_t *) tensor->op_params)[0];