From 46862d725bf5ef4abed676b171cd53178ef72cdc Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 26 Aug 2024 18:22:29 +0300
Subject: [PATCH] Add softcap to flash attention

Just CPU and CUDA for now (but, as we know, flash attention on the CPU
is useless in llama.cpp). On CUDA this improves PP performance quite a
bit, especially for long contexts. E.g., for PP-16384 I now get 3777 t/s.
Without this change one cannot use FA at all, and one gets 2300 t/s
(after fusing softcap and softmax) or 2000 t/s without the fused
softcap+softmax. In comparison, mainline llama.cpp gets PP-16384 = 1549 t/s
before PR-8542 (where Johannes Gaessler also added softcap to FA), and
PP-16384 = 3097 t/s after that PR.
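For reference, this is the logit transformation the kernels implement,
written out as a scalar sketch (fattn_logit and kq_dot are illustrative
names, not part of the patch):

    #include <math.h>

    // Softcapped attention logit as used by Gemma-2:
    //   s = softcap * tanh(scale * (Q.K) / softcap)
    // launch_fattn() folds the 1/softcap factor into scale up front
    // (scale /= softcap), so the kernels only compute softcap*tanhf(sum).
    static inline float fattn_logit(float kq_dot, float scale, float softcap) {
        if (softcap == 0.0f) {
            return scale * kq_dot; // capping disabled
        }
        return softcap * tanhf(scale * kq_dot / softcap);
    }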
---
 ggml/include/ggml.h                   |  3 +-
 ggml/src/ggml-cuda/fattn-common.cuh   |  9 +++-
 ggml/src/ggml-cuda/fattn-tile-f16.cu  | 46 +++++++++++++++----
 ggml/src/ggml-cuda/fattn-tile-f32.cu  | 40 ++++++++++++++---
 ggml/src/ggml-cuda/fattn-vec-f16.cuh  | 59 +++++++++++++++++++------
 ggml/src/ggml-cuda/fattn-vec-f32.cuh  | 56 +++++++++++++++++++-----
 ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 38 ++++++++++++++--
 ggml/src/ggml-cuda/fattn.cu           |  4 +-
 ggml/src/ggml-metal.m                 |  3 ++
 ggml/src/ggml.c                       | 19 +++++---
 src/llama.cpp                         | 63 +++++++++------------------
 tests/test-backend-ops.cpp            | 22 ++++++----
 12 files changed, 257 insertions(+), 105 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 1705c1ca..1a4a516c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1811,7 +1811,8 @@ extern "C" {
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
             float                 scale,
-            float                 max_bias);
+            float                 max_bias,
+            float                 softcap);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 950fd93d..e4021764 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -659,9 +660,15 @@ void launch_fattn(
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
+    float softcap  = 0.0f;
 
     memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&softcap,  (float *) KQV->op_params + 2, sizeof(float));
+
+    if (softcap != 0.0f) {
+        scale /= softcap;
+    }
 
     const uint32_t n_head      = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
+        scale, max_bias, m0, m1, softcap, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
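Note on the op_params layout of GGML_OP_FLASH_ATTN_EXT after this patch:
softcap takes slot 2, which pushes the precision flag from slot 2 to slot 3.
This is why every read of the precision flag below changes from op_params[2]
to op_params[3]. Schematically (a sketch, not part of the patch):

    // op_params of GGML_OP_FLASH_ATTN_EXT:
    //   [0] float   scale
    //   [1] float   max_bias
    //   [2] float   softcap    (new; 0.0f means disabled)
    //   [3] int32_t precision  (was op_params[2] before this patch)
    float softcap;
    memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));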
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 1b2fd500..d1bbf01f 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
-                half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                half sum;
+                if (use_softcap) {
+                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                    sum = softcap * tanhf(tmp.x + tmp.y);
+                } else {
+                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
                 sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float softcap;
+    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (softcap == 0.0f) {
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+        } else {
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (softcap == 0.0f) {
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+        } else {
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (softcap == 0.0f) {
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+    } else {
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+    }
 }
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index f3e68dbf..25908d7a 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
+                if (use_softcap) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+
                 sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
                 kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int D      = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D      = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
+    float softcap;
+    memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (softcap == 0.0f) {
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+        } else {
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (softcap == 0.0f) {
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+        } else {
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (softcap == 0.0f) {
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
+    } else {
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
+    }
 }
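The host-side dispatch pattern above repeats for every kernel family in this
patch, including the vec and wmma kernels below: softcap is a runtime value,
but the kernels take use_softcap as a bool template parameter so that the
common softcap == 0.0f path pays no per-element cost. A minimal sketch with
illustrative names (fattn_sketch is not part of the patch):

    #include <cstring>

    template <bool use_softcap>
    __global__ void fattn_sketch(const float softcap, float * s) {
        if (use_softcap) {            // resolved at compile time;
            *s = softcap*tanhf(*s);   // dead code in the <false> instantiation
        }
        // ... softmax / V accumulation would follow ...
    }

    static void launch_fattn_sketch(const int32_t * op_params, float * s, cudaStream_t stream) {
        float softcap;
        memcpy(&softcap, (const float *) op_params + 2, sizeof(float));
        if (softcap == 0.0f) {
            fattn_sketch<false><<<1, 1, 0, stream>>>(softcap, s);
        } else {
            fattn_sketch<true><<<1, 1, 0, stream>>>(softcap, s);
        }
    }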
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 02a4ad07..cf628dd5 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
+            if (use_softcap) {
+                sum = softcap*tanhf(sum);
+            }
             sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
@@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
-    ggml_tensor * Q   = dst->src[0];
-    ggml_tensor * K   = dst->src[1];
-    ggml_tensor * V   = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
+    float softcap;
+    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
        return;
     }
 
     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (softcap == 0.0f) {
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    } else {
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+    }
 }
 
 #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 11a5e355..1aa88272 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
+            if (use_softcap) {
+                sum = softcap*tanhf(sum);
+            }
             sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q = dst->src[0];
-    ggml_tensor * K = dst->src[1];
-    ggml_tensor * V = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
 
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
+    float softcap;
+    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (softcap == 0.0f) {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (softcap == 0.0f) {
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    } else {
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+    }
 }
 
 #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index ae232224..efe78a2f 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
@@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16(
         const float max_bias,
         const float m0,
         const float m1,
+        const float softcap,
         const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16(
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
+    const half2 softcap_2 = make_half2(softcap, softcap);
 
     frag_b Q_b[D/16][ncols/frag_n];
 
@@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + threadIdx.x;
 
             KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+            if (use_softcap) {
+                KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+            }
         }
 
         float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + threadIdx.x;
 
             KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+            if (use_softcap) {
+                // There is no dedicated tangens hyperbolicus function for half2,
+                // so compute it via exp. Note: the expression below can produce NaNs on overflow.
+                KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                       /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                KQ2_tmp[k0/WARP_SIZE] *= softcap_2;
+            }
         }
 
         half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
+    float softcap;
+    memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
+
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    fattn_kernel_t fattn_kernel = softcap == 0.0f ?
+        flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+        flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
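On the half2 tanh workaround above: it uses the identity
tanh(x) = (exp(2x) - 1)/(exp(2x) + 1). In fp16, h2exp() overflows to +inf
once 2x exceeds log(65504) ~= 11.09, at which point the quotient becomes
inf/inf = NaN; that is the overflow the comment warns about. The same
computation for one scalar (a sketch; tanh_via_exp is an illustrative name):

    #include <math.h>

    static inline float tanh_via_exp(float x) {
        const float e = expf(2.0f*x); // +inf for large x in low precision
        return (e - 1.0f)/(e + 1.0f); // then inf/inf = NaN
    }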
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 29f608b0..f87f33b3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
 
     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_set_device(ctx.device);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 1e940c5b..bc3e31cb 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -901,6 +901,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            float softcap;
+            memcpy(&softcap, ((const float *) op->op_params) + 2, sizeof(softcap));
+            if (softcap != 0.0f) return false;
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 789c4180..cebac584 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7596,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
@@ -7623,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7643,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second, softcap on third
 }
 
 // ggml_flash_attn_back
@@ -16138,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
+    float softcap  = 0.0f;
 
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&softcap,  (float *) dst->op_params + 2, sizeof(float));
+
+    if (softcap != 0.0f) {
+        scale /= softcap;
+    }
 
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -16204,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                s = s*scale + mv; // scale KQ value and apply mask
+                s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
 
                 const float Mold = M;
@@ -16213,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
                 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-                if (v->type== GGML_TYPE_F16) {
+                if (v->type == GGML_TYPE_F16) {
                     if (s > M) {
                         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
@@ -16280,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
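Note that in the CPU path above, scale has already been divided by softcap at
the top of ggml_compute_forward_flash_attn_ext_f16, so the capped branch
expands to the reference Gemma-2 formula:

    softcap*tanhf(s*scale) + mv  ==  softcap * tanh(s * scale_orig / softcap) + mv
                                     (with scale == scale_orig / softcap)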
diff --git a/src/llama.cpp b/src/llama.cpp
index 5294bf66..8a85144e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -13222,47 +13223,31 @@ struct llm_build_context {
                 0);
             cb(k, "k", il);
 
-            if (cparams.flash_attn) {
+            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            cb(kq, "kq", il);
 
-                // split cached v into n_head heads (not transposed)
-                struct ggml_tensor * v =
-                    ggml_view_3d(ctx0, kv_self.v_l[il],
-                            n_embd_head_v, n_kv, n_head_kv,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                            0);
-                cb(v, "v", il);
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
 
-                cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+            GGML_ASSERT(kv_self.size == n_ctx);
 
-                cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-            } else {
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
+            // split cached v into n_head heads
+            struct ggml_tensor * v =
+                ggml_view_3d(ctx0, kv_self.v_l[il],
+                        n_kv, n_embd_head_v, n_head_kv,
+                        ggml_element_size(kv_self.v_l[il])*n_ctx,
+                        ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                        0);
+            cb(v, "v", il);
 
-                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
+            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cb(kqv, "kqv", il);
 
-                GGML_ASSERT(kv_self.size == n_ctx);
+            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cb(kqv_merged, "kqv_merged", il);
 
-                // split cached v into n_head heads
-                struct ggml_tensor * v =
-                    ggml_view_3d(ctx0, kv_self.v_l[il],
-                            n_kv, n_embd_head_v, n_head_kv,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                            0);
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-                cb(cur_attn, "kqv_merged_cont", il);
-            }
+            cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+            cb(cur_attn, "kqv_merged_cont", il);
 
             cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
                     model.layers[il].attn_sub_norm, NULL,
@@ -16813,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.attn_soft_cap) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a2182c1b..f51ec5b8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
     const bool mask;        // use mask
     const float max_bias;   // ALiBi
+    const float softcap;    // Gemma-2
     const ggml_type type_KV;
 
     std::string vars() override {
-        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+        return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, softcap, type_KV);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), softcap(softcap), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, softcap);
         return out;
     }
 };
@@ -2434,11 +2435,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
         for (bool mask : { true, false } ) {
             for (float max_bias : { 0.0f, 8.0f }) {
                 if (!mask && max_bias > 0.0f) continue;
-                for (int nh : { 32, }) {
-                    for (int kv : { 512, 1024, }) {
-                        for (int nb : { 1, 2, 4, 8, }) {
-                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+                for (float softcap : {0.0f, 10.0f}) {
+                    if (hs != 128 && softcap != 0.0f) continue;
+                    for (int nh : { 32, }) {
+                        for (int kv : { 512, 1024, }) {
+                            for (int nb : { 1, 2, 4, 8, }) {
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, softcap, type_KV));
+                                }
                             }
                         }
                     }
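For completeness, a minimal sketch of calling the extended API, mirroring what
build_graph() above does (build_fattn_softcap and the softcap value of 10.0f
are illustrative, not part of the patch):

    #include "ggml.h"
    #include <math.h>

    static struct ggml_tensor * build_fattn_softcap(
            struct ggml_context * ctx,
            struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v,
            struct ggml_tensor * mask, int head_size) {
        const float scale   = 1.0f/sqrtf((float) head_size);
        const float softcap = 10.0f; // pass 0.0f to disable capping
        return ggml_flash_attn_ext(ctx, q, k, v, mask, scale, /*max_bias =*/ 0.0f, softcap);
    }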