mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
Add softcap to flash attention
Just CPU and CUDA for now (but, as we know, flash attention on the CPU is useless in llama.cpp). On CUDA this improves PP performance quite a bit, especially for long contexts. E.g., for PP-16384, I now get 3777 t/s. Without this change, one cannot use FA, and one gets 2300 t/s (after fusing softcap and softmax), or 2000 t/s without the fused softcap+softmax. In comparison, mainline llama.cpp has PP-16384 = 1549 t/s before PR-8542 (where Johannes Gaessler has also added softcap to FA), and PP-16384 = 3097 t/s after this PR.
This commit is contained in:
@@ -1811,7 +1811,8 @@ extern "C" {
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * mask,
|
||||
float scale,
|
||||
float max_bias);
|
||||
float max_bias,
|
||||
float softcap);
|
||||
|
||||
GGML_API void ggml_flash_attn_ext_set_prec(
|
||||
struct ggml_tensor * a,
|
||||
|
||||
@@ -21,6 +21,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -659,9 +660,15 @@ void launch_fattn(
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&softcap, (float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (softcap != 0.0f) {
|
||||
scale /= softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
@@ -675,7 +682,7 @@ void launch_fattn(
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2,
|
||||
scale, max_bias, m0, m1, softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
half sum;
|
||||
if (use_softcap) {
|
||||
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
sum = softcap * tanhf(tmp.x + tmp.y);
|
||||
} else {
|
||||
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
}
|
||||
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
||||
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
template <int cols_per_block, int parallel_blocks, bool use_softcap>
|
||||
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
@@ -296,24 +309,39 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
float softcap;
|
||||
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -19,6 +19,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
if (use_softcap) {
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
}
|
||||
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
template <int cols_per_block, int parallel_blocks, bool use_softcap>
|
||||
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
@@ -292,21 +303,36 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
float softcap;
|
||||
memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
|
||||
} else {
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
@@ -190,6 +197,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
if (use_softcap) {
|
||||
sum = softcap*tanhf(sum);
|
||||
}
|
||||
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
if (ncols == 1) {
|
||||
@@ -286,10 +296,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
@@ -297,48 +307,71 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * KQV = dst;
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float softcap;
|
||||
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
||||
@@ -180,6 +187,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
if (use_softcap) {
|
||||
sum = softcap*tanhf(sum);
|
||||
}
|
||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
@@ -267,10 +277,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
@@ -278,44 +288,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float softcap;
|
||||
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -21,6 +21,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
@@ -84,6 +91,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
const half2 softcap_2 = make_half2(softcap, softcap);
|
||||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
|
||||
@@ -194,6 +202,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
||||
if (use_softcap) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
}
|
||||
|
||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||
@@ -237,6 +248,16 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||
if (use_softcap) {
|
||||
// There is no dedicated tangens hyperbolicus function for half2.
|
||||
// Yes, and the code below can produce NaNs on overflow
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
||||
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
||||
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] *= softcap_2;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||
@@ -435,20 +456,29 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
float softcap;
|
||||
memcpy(&softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
|
||||
if (precision != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
|
||||
@@ -901,6 +901,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
if (op->src[0]->ne[0] == 256) {
|
||||
return false;
|
||||
}
|
||||
float softcap;
|
||||
memcpy(&softcap, ((const float *) op->op_params) + 2, sizeof(softcap));
|
||||
if (softcap != 0.0f) return false;
|
||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
|
||||
@@ -7596,7 +7596,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * mask,
|
||||
float scale,
|
||||
float max_bias) {
|
||||
float max_bias,
|
||||
float softcap) {
|
||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
||||
// TODO: check if vT can be multiplied by (k*qT)
|
||||
|
||||
@@ -7623,7 +7624,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
float params[] = { scale, max_bias };
|
||||
float params[] = { scale, max_bias, softcap };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_FLASH_ATTN_EXT;
|
||||
@@ -7643,7 +7644,7 @@ void ggml_flash_attn_ext_set_prec(
|
||||
|
||||
const int32_t prec_i32 = (int32_t) prec;
|
||||
|
||||
ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
|
||||
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||
}
|
||||
|
||||
// ggml_flash_attn_back
|
||||
@@ -16138,9 +16139,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
memcpy(&softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (softcap != 0.0f) {
|
||||
scale /= softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
@@ -16204,7 +16211,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||
|
||||
s = s*scale + mv; // scale KQ value and apply mask
|
||||
s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask
|
||||
|
||||
const float Mold = M;
|
||||
|
||||
@@ -16213,7 +16220,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
|
||||
if (v->type== GGML_TYPE_F16) {
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
if (s > M) {
|
||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||
M = s;
|
||||
@@ -16280,7 +16287,7 @@ static void ggml_compute_forward_flash_attn_ext(
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * mask,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (dst->op_params[2]) {
|
||||
switch (dst->op_params[3]) {
|
||||
case GGML_PREC_DEFAULT:
|
||||
case GGML_PREC_F32:
|
||||
{
|
||||
|
||||
@@ -8290,7 +8290,8 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
|
||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
@@ -13222,47 +13223,31 @@ struct llm_build_context {
|
||||
0);
|
||||
cb(k, "k", il);
|
||||
|
||||
if (cparams.flash_attn) {
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
// split cached v into n_head heads (not transposed)
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
GGML_ASSERT(kv_self.size == n_ctx);
|
||||
|
||||
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
|
||||
} else {
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
GGML_ASSERT(kv_self.size == n_ctx);
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||
n_kv, n_embd_head_v, n_head_kv,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||
0);
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||
cb(kqv, "kqv", il);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||
cb(cur_attn, "kqv_merged_cont", il);
|
||||
}
|
||||
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||
cb(cur_attn, "kqv_merged_cont", il);
|
||||
|
||||
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
|
||||
model.layers[il].attn_sub_norm, NULL,
|
||||
@@ -16813,12 +16798,6 @@ struct llama_context * llama_new_context_with_model(
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (params.flash_attn && model->hparams.attn_soft_cap) {
|
||||
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
|
||||
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
|
||||
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
|
||||
params.flash_attn = false;
|
||||
|
||||
@@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
|
||||
const bool mask; // use mask
|
||||
|
||||
const float max_bias; // ALiBi
|
||||
const float softcap; // Gemma-2
|
||||
|
||||
const ggml_type type_KV;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
|
||||
return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, softcap, type_KV);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
|
||||
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
|
||||
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), softcap(softcap), type_KV(type_KV) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
||||
@@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
||||
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, softcap);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
@@ -2434,11 +2435,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
for (bool mask : { true, false } ) {
|
||||
for (float max_bias : { 0.0f, 8.0f }) {
|
||||
if (!mask && max_bias > 0.0f) continue;
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 2, 4, 8, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
|
||||
for (float softcap : {0.0f, 10.0f}) {
|
||||
if (hs != 128 && softcap != 0.0f) continue;
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 2, 4, 8, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, softcap, type_KV));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user