diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1705c1ca..1a4a516c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1811,7 +1811,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias); + float max_bias, + float softcap); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 950fd93d..e4021764 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -659,9 +660,15 @@ void launch_fattn( float scale = 1.0f; float max_bias = 0.0f; + float softcap = 0.0f; memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&softcap, (float *) KQV->op_params + 2, sizeof(float)); + + if (softcap != 0.0f) { + scale /= softcap; + } const uint32_t n_head = Q->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -675,7 +682,7 @@ void launch_fattn( V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, + scale, max_bias, m0, m1, softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 1b2fd500..d1bbf01f 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; - half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + half sum; + if (use_softcap) { + const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + sum = softcap * tanhf(tmp.x + tmp.y); + } else { + sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); @@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); + float softcap; + memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f16_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + launch_fattn_tile_f16_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f3e68dbf..25908d7a 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; + if (use_softcap) { + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); @@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32( } } -template +template void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; + float softcap; + memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f32_64_128(ctx, dst); + if (softcap == 0.0f) { + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + launch_fattn_tile_f32_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 02a4ad07..cf628dd5 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); @@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + if (use_softcap) { + sum = softcap*tanhf(sum); + } sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * KQV = dst; - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float softcap; + memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 11a5e355..1aa88272 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); @@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + if (use_softcap) { + sum = softcap*tanhf(sum); + } sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32( } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float softcap; + memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (softcap == 0.0f) { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index ae232224..efe78a2f 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -6,7 +6,7 @@ #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16( const float max_bias, const float m0, const float m1, + const float softcap, const uint32_t n_head_log2, const int ne00, const int ne01, @@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. @@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16( const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); + const half2 softcap_2 = make_half2(softcap, softcap); frag_b Q_b[D/16][ncols/frag_n]; @@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + if (use_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } } float KQ_max_new = KQ_max_f[j0/nwarps]; @@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + if (use_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + // Yes, and the code below can produce NaNs on overflow + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= softcap_2; + } + } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; @@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + float softcap; + memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float)); + if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel = softcap == 0.0f ? + flash_attn_ext_f16 : + flash_attn_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel = softcap == 0.0f ? + flash_attn_ext_f16 : + flash_attn_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel = softcap == 0.0f ? + flash_attn_ext_f16 : + flash_attn_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 29f608b0..f87f33b3 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; if (precision != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 1e940c5b..bc3e31cb 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -901,6 +901,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx if (op->src[0]->ne[0] == 256) { return false; } + float softcap; + memcpy(&softcap, ((const float *) op->op_params) + 2, sizeof(softcap)); + if (softcap != 0.0f) return false; return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 789c4180..cebac584 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7596,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias) { + float max_bias, + float softcap) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7623,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale, max_bias }; + float params[] = { scale, max_bias, softcap }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -7643,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_attn_back @@ -16138,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; float max_bias = 0.0f; + float softcap = 0.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (softcap != 0.0f) { + scale /= softcap; + } const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -16204,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - s = s*scale + mv; // scale KQ value and apply mask + s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask const float Mold = M; @@ -16213,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - if (v->type== GGML_TYPE_F16) { + if (v->type == GGML_TYPE_F16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -16280,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[2]) { + switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { diff --git a/src/llama.cpp b/src/llama.cpp index 5294bf66..8a85144e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -13222,47 +13223,31 @@ struct llm_build_context { 0); cb(k, "k", il); - if (cparams.flash_attn) { + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); - // split cached v into n_head heads (not transposed) - struct ggml_tensor * v = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), - 0); - cb(v, "v", il); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); - cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + GGML_ASSERT(kv_self.size == n_ctx); - cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); - } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); - GGML_ASSERT(kv_self.size == n_ctx); + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self.v_l[il])*n_ctx, - ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur_attn, "kqv_merged_cont", il); - } + cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur_attn, "kqv_merged_cont", il); cur_attn = llm_build_norm(ctx0, cur_attn, hparams, model.layers[il].attn_sub_norm, NULL, @@ -16813,12 +16798,6 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.attn_soft_cap) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__); - params.flash_attn = false; - } - - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a2182c1b..f51ec5b8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case { const bool mask; // use mask const float max_bias; // ALiBi + const float softcap; // Gemma-2 const ggml_type type_KV; std::string vars() override { - return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV); + return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, softcap, type_KV); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), softcap(softcap), type_KV(type_KV) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); @@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr; - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, softcap); return out; } }; @@ -2434,11 +2435,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (bool mask : { true, false } ) { for (float max_bias : { 0.0f, 8.0f }) { if (!mask && max_bias > 0.0f) continue; - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV)); + for (float softcap : {0.0f, 10.0f}) { + if (hs != 128 && softcap != 0.0f) continue; + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, softcap, type_KV)); + } } } }