WIP KQ binary mask: CUDA

Relatively painless to implement for soft_max and soft_cap_max.
We gain 11.5% for LLaMA-8B and ~14% for Gemma-2-2b at 32k tokens.
The KQ mask is prepared on the CPU and copied to the GPU, so
my guess is that most of it comes from the 32X reduction in the
amount of data being copied to the GPU.

TODO: flash attention
This commit is contained in:
Iwan Kawrakow
2024-08-28 10:03:10 +03:00
parent 511c459232
commit 1216a43719
2 changed files with 31 additions and 17 deletions

View File

@@ -2,13 +2,18 @@
#include "softmax.cuh"
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
return (float) val;
static __device__ __forceinline__ float mask_value(float slope, const T * mask, int iy) {
return mask ? slope * (float)mask[iy] : 0.0f;
}
template <>
__device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val);
__device__ __forceinline__ float mask_value(float slope, const half * mask, int iy) {
return mask ? slope * __half2float(mask[iy]) : 0.0f;
}
template <>
__device__ __forceinline__ float mask_value(float, const uint32_t * mask, int iy) {
return mask[iy >> 5] & (1u << (iy & 31)) ? 0.0f : -INFINITY;
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
@@ -44,8 +49,8 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;
const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + (mask ? slope*t2f32(mask[iy]) : 0.0f) :
scale*x[ix] + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = do_softcap ? scale*cap_params1*tanhf(cap_params0*x[ix]) + mask_value(slope, mask, iy) :
scale*x[ix] + mask_value(slope, mask, iy);
vals[col] = val;
max_val = max(max_val, val);
@@ -181,7 +186,7 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -194,14 +199,17 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32);
if (use_f16) {
if (use_i32) {
const uint32_t * mask = (const uint32_t *)src1_d;
soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
else if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
}
@@ -219,7 +227,7 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_I32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@@ -229,15 +237,17 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
memcpy(params, dst->op_params, sizeof(params));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
//printf("%s: %g, %g, %g, %g, %p, %d\n", __func__, params[0], params[1], params[2], params[3], (const void *)src1, use_f16);
const bool use_i32 = (src1 && src1->type == GGML_TYPE_I32);
if (use_f16) {
if (use_i32) {
const uint32_t * mask = (const uint32_t *)src1_d;
soft_max_f32_cuda(src0_d, mask, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
else if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}

View File

@@ -6850,10 +6850,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32 || mask->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
if (mask->type == GGML_TYPE_I32) {
GGML_ASSERT(mask->ne[0] == (a->ne[0] + 31)/32);
} else {
GGML_ASSERT(mask->ne[0] == a->ne[0]);
}
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}