Alternative CUDA FA for SWA models (#754)

* Bounds for flash attention

* Add n_swa to FA parameters

* Fix it

* This seems very slightly better

* Using vec kernel when we have SWA

* Need also this

* f32 vec kernel

* This is slightly better

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-09-04 08:42:18 +02:00
Committed by: GitHub
Parent: 727f7b7d9f
Commit: 4a6a6f17ee
7 changed files with 1528 additions and 209 deletions


@@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16;
typedef tile<16, 4, half2> tile_C_VKQ;
typedef tile<16, 8, half2> tile_C_VKQ_16;
typedef void (* fattn_kernel_mma_t)(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ bounds,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3);
template<int D, int nwarps, int KQ_per_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
@@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16(
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ bounds,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (bounds) {
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
}
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
@@ -987,8 +1032,12 @@ static __global__ void flash_attn_mma_ext_f16(
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
const int kb0_stop_kernel = kb0_stop * kb_niter;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (bounds) {
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
}
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
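For reference, a minimal host-side sketch of the clamping above, with made-up numbers: the per-tile bounds produced from the mask are in KV-token units (multiples of FATTN_KQ_STRIDE), while the kernel walks the KV cache in steps of KQ_per_iter tokens, so they are converted by integer division before narrowing the iteration range.

#include <algorithm>
#include <cstdio>

int main() {
    const int KQ_per_iter = 64;                 // assumed per-iteration KV step
    int kb0_start_kernel = 0;                   // full range: 4096 KV tokens
    int kb0_stop_kernel  = 4096 / KQ_per_iter;
    const int bound_lo = 2560, bound_hi = 3584; // hypothetical KV_min_max entry for this tile
    kb0_start_kernel = std::max(kb0_start_kernel, bound_lo / KQ_per_iter);
    kb0_stop_kernel  = std::min(kb0_stop_kernel,  bound_hi / KQ_per_iter);
    std::printf("iterate kb0 in [%d, %d)\n", kb0_start_kernel, kb0_stop_kernel); // [40, 56)
    return 0;
}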
@@ -1144,9 +1193,109 @@ static __global__ void flash_attn_mma_combine_results(
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
return __all_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
}
return x;
}
}
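To make the sub-warp branch above concrete, here is a host-side emulation of the XOR-butterfly reduction (an illustrative sketch only, lane values made up): after log2(width) exchange steps every lane holds the logical AND of all lanes, which is what the __shfl_xor_sync loop computes when the full-warp __all_sync shortcut does not apply.

#include <array>
#include <cstdio>

constexpr int WIDTH = 32;

// Each "lane" holds a flag; one step with stride `offset` mirrors
// x = __shfl_xor_sync(0xffffffff, x, offset, width) && x on the device.
static void butterfly_all(std::array<int, WIDTH> & lanes) {
    for (int offset = WIDTH/2; offset > 0; offset >>= 1) {
        std::array<int, WIDTH> next = lanes;
        for (int lane = 0; lane < WIDTH; ++lane) {
            next[lane] = lanes[lane ^ offset] && lanes[lane];
        }
        lanes = next;
    }
}

int main() {
    std::array<int, WIDTH> lanes;
    lanes.fill(1);
    lanes[17] = 0;  // one lane sees a non-masked-out element
    butterfly_all(lanes);
    std::printf("all lanes agree: %d\n", lanes[0]);  // prints 0
    return 0;
}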
template <int ncols1, bool is_swa>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_min_max(
const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
const int jt = blockIdx.x;
mask += sequence*s33 + jt*ncols1*s31;
__shared__ int buf_iw[WARP_SIZE];
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
int all_inf = 1;
#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}
all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
all_inf = buf_iw[tid % WARP_SIZE];
__syncthreads();
all_inf = warp_reduce_all(all_inf);
if (!all_inf) {
break;
}
}
if constexpr (!is_swa) {
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE};
}
return;
}
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
}
if (tid < WARP_SIZE) {
buf_iw[tid] = 1;
}
__syncthreads();
int KV_min_sj = 0;
for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) {
int all_inf = 1;
#pragma unroll
for (int j = 0; j < ncols1; ++j) {
const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]);
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
}
all_inf = warp_reduce_all(all_inf);
if (tid % WARP_SIZE == 0) {
buf_iw[tid / WARP_SIZE] = all_inf;
}
__syncthreads();
all_inf = buf_iw[tid % WARP_SIZE];
__syncthreads();
all_inf = warp_reduce_all(all_inf);
if (!all_inf) {
break;
}
}
if (threadIdx.x == 0) {
KV_min_max[sequence*ne31 + jt].x = KV_min_sj;
}
}
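As an illustration of what this kernel produces, a CPU reference under simplifying assumptions (row-major float mask, a single query tile, all names invented for this sketch): the result is the first and last FATTN_KQ_STRIDE-aligned KV chunk that is not entirely -inf for the rows of the tile, i.e. the [KV_min, KV_max) range the attention kernels actually have to visit.

#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

constexpr int FATTN_KQ_STRIDE = 256;

// mask: ncols1 rows of n_kv floats (0 = attend, -inf = masked); n_kv is a multiple of FATTN_KQ_STRIDE.
static std::pair<int, int> kv_bounds(const std::vector<float> & mask, int n_kv, int ncols1) {
    auto chunk_all_inf = [&](int k0) {
        for (int j = 0; j < ncols1; ++j)
            for (int k = k0; k < k0 + FATTN_KQ_STRIDE; ++k)
                if (!std::isinf(mask[(size_t)j*n_kv + k])) return false;
        return true;
    };
    int kv_max = n_kv;
    while (kv_max > 0 && chunk_all_inf(kv_max - FATTN_KQ_STRIDE)) kv_max -= FATTN_KQ_STRIDE;
    int kv_min = 0;
    while (kv_min < kv_max && chunk_all_inf(kv_min)) kv_min += FATTN_KQ_STRIDE;
    return {kv_min, kv_max};
}

int main() {
    const int n_kv = 1024, ncols1 = 2, n_swa = 256;
    std::vector<float> mask((size_t)ncols1*n_kv, -INFINITY);
    for (int j = 0; j < ncols1; ++j)                        // both rows attend to a window ending at 768
        for (int k = 768 - n_swa; k < 768; ++k) mask[(size_t)j*n_kv + k] = 0.0f;
    const std::pair<int, int> b = kv_bounds(mask, n_kv, ncols1);
    std::printf("KV bounds: [%d, %d)\n", b.first, b.second); // [512, 768)
    return 0;
}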
template <int D, int ncols1, int ncols2, int KQ_stride>
void launch_fattn_mma(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;
@@ -1171,6 +1320,9 @@ void launch_fattn_mma(
GGML_ASSERT(Q->ne[3] == 1);
int n_swa;
memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int));
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -1179,6 +1331,7 @@ void launch_fattn_mma(
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<int2> KV_min_max(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
@@ -1225,11 +1378,29 @@ void launch_fattn_mma(
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_min_max.alloc(ne_KV_max);
if (n_swa > 0) {
flash_attn_mask_to_KV_min_max<ncols1, true><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
} else {
flash_attn_mask_to_KV_min_max<ncols1, false><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
}
CUDA_CHECK(cudaGetLastError());
}
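A hedged worked example of the condition and launch geometry above (all sizes hypothetical): the scan is only launched when it is likely to pay off, either for long prompts or for an SWA model whose KV cache already extends past the window, using one block of FATTN_KQ_STRIDE/2 threads per query tile and sequence.

#include <cstdio>

int main() {
    const int FATTN_KQ_STRIDE = 256, ncols1 = 8;                 // assumed query-tile width
    const int ne1_q = 32, ne3_q = 1, ne1_k = 8192, n_swa = 4096; // hypothetical tensor sizes
    const bool run_scan = ne1_q >= 1024 || (n_swa > 0 && ne1_k >= FATTN_KQ_STRIDE + n_swa);
    const int ntiles_x = (ne1_q + ncols1 - 1)/ncols1;
    std::printf("run=%d  grid=(%d,%d)  block=%d  iter_k=%d\n",
                run_scan, ntiles_x, ne3_q, FATTN_KQ_STRIDE/2, ne1_k/FATTN_KQ_STRIDE);
    return 0;
}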
const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (stream_k) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = 2*nsm;
const int max_blocks = Q->ne[1] > 1 ? 2*nsm : nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
@@ -1313,6 +1484,7 @@ void launch_fattn_mma(
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
KV_min_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1372,7 +1544,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
fattn_kernel_t fattn_kernel;
fattn_kernel_mma_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;

File diff suppressed because it is too large.


@@ -6,9 +6,9 @@
//
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-vec-common.cuh"
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
template<int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K, V head sizes
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(Dk, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -18,44 +18,37 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ KV_min_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
@@ -63,18 +56,19 @@ static __global__ void flash_attn_vec_ext_f16(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio);
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +81,12 @@ static __global__ void flash_attn_vec_ext_f16(
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
half kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -HALF_MAX_HALF;
kqsum[j] = 0.0f;
}
half kqsum[ncols] = {0.0f};
__shared__ half kqmax_shared[ncols][WARP_SIZE];
__shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -102,8 +97,21 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
__shared__ half maskh_shared[ncols*Dk];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*Dk + tid] = 0.0f;
}
__syncthreads();
half2 VKQ[ncols] = {{0.0f, 0.0f}};
const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
if (first_y + blockIdx.y*Dk < k_VKQ_max) {
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
half2 Q_h2[ncols][Dk/(2*WARP_SIZE)];
int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)];
@@ -179,12 +187,19 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
KQ[j*Dk + tid] = -HALF_MAX_HALF;
}
__syncthreads();
half2 VKQ[ncols] = {{0.0f, 0.0f}};
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
//if (KV_min_max && threadIdx.x == 0 && head == 0) {
// printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk);
//}
K += (first_y + blockIdx.y*Dk) * nb11;
V += (first_y + blockIdx.y*Dv) * nb21;
maskh += (first_y + blockIdx.y*Dk);
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
// Increment pointers after each loop:
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
@@ -196,6 +211,14 @@ static __global__ void flash_attn_vec_ext_f16(
kqmax_new_arr[j] = kqmax[j];
}
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskh_shared[j*Dk + tid] = slopeh*maskh[j*ne11 + tid];
}
__syncthreads();
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
@@ -206,12 +229,14 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_softcap) {
sum = softcap*tanhf(sum);
half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum((float)sum);
if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += maskh_shared[j*Dk + i_KQ];
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -229,7 +254,6 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
@@ -261,8 +285,8 @@ static __global__ void flash_attn_vec_ext_f16(
}
half2 V_k;
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];
@@ -271,9 +295,10 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
}
}
if (sinksf) {
const half sink = __float2half(sinksf[blockIdx.y]);
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -286,7 +311,7 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = kqmax_shared[j][threadIdx.y];
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
@@ -307,7 +332,7 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
if (threadIdx.x == 0) {
kqsum_shared[j][threadIdx.y] = kqsum[j];
}
@@ -322,32 +347,31 @@ static __global__ void flash_attn_vec_ext_f16(
}
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
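With gridDim.y > 1 the kernel leaves its per-block output unnormalized and stores the running (kqmax, kqsum) pair in dst_meta; a separate combine pass then merges the partial results. A minimal host sketch of that merge with the usual log-sum-exp rescaling (values and names are made up):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> m = {1.2f, 3.4f, 2.0f};  // per-block running maxima (kqmax)
    std::vector<float> s = {5.0f, 2.5f, 4.0f};  // per-block sums of exp(kq - kqmax) (kqsum)
    std::vector<float> o = {0.7f, 0.2f, 0.5f};  // per-block unnormalized outputs
    float m_all = -INFINITY;
    for (float mb : m) m_all = std::fmax(m_all, mb);
    float num = 0.0f, den = 0.0f;
    for (size_t b = 0; b < m.size(); ++b) {
        const float w = std::exp(m[b] - m_all); // rescale each block to the global maximum
        num += w*o[b];
        den += w*s[b];
    }
    std::printf("combined = %f\n", num/den);
    return 0;
}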
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = Dk/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<Dv, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
}
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -363,59 +387,54 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float softcap;
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 8;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}
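Both vec kernels now stage the mask values for the current K tile in shared memory (maskh_shared above, maskf_shared in the f32 variant) once per iteration and read them repeatedly while accumulating the KQ sums, instead of going back to global memory inside the inner loop. A self-contained CUDA sketch of that pattern, with an invented kernel and made-up sizes:

#include <cstdio>
#include <cuda_runtime.h>

constexpr int D = 128;  // assumed tile width (one value per thread)

__global__ void stage_mask_and_sum(const float * mask, float * out, int n_kv) {
    __shared__ float mask_shared[D];
    const int tid = threadIdx.x;
    float acc = 0.0f;
    for (int k0 = 0; k0 < n_kv; k0 += D) {
        mask_shared[tid] = mask[k0 + tid];   // one coalesced global load per tile
        __syncthreads();
        for (int i = 0; i < D; ++i) {        // many cheap shared-memory reads
            acc += mask_shared[i];
        }
        __syncthreads();
    }
    out[tid] = acc;
}

int main() {
    const int n_kv = 512;
    float *mask, *out;
    cudaMallocManaged(&mask, n_kv*sizeof(float));
    cudaMallocManaged(&out, D*sizeof(float));
    for (int k = 0; k < n_kv; ++k) mask[k] = (k % 2) ? 0.0f : -1.0f;
    stage_mask_and_sum<<<1, D>>>(mask, out, n_kv);
    cudaDeviceSynchronize();
    std::printf("out[0] = %.0f\n", out[0]);  // -256: every thread summed the whole staged mask
    cudaFree(mask); cudaFree(out);
    return 0;
}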


@@ -6,55 +6,54 @@
//
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-vec-common.cuh"
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// Currently llvm with the amdgcn target does not support unrolling loops
// that contain a break that cannot be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K-, V-head size
#ifndef GGML_USE_HIP
__launch_bounds__(Dk, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ KV_min_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
// Skip unused kernel variants for faster compilation:
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
@@ -62,17 +61,19 @@ static __global__ void flash_attn_vec_ext_f32(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
}
float kqmax[ncols];
float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
kqsum[j] = 0.0f;
}
float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -102,6 +104,13 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
__shared__ float maskf_shared[ncols*Dk];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*Dk + tid] = 0.0f;
}
__syncthreads();
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -176,10 +185,26 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
K += (first_y + blockIdx.y*Dk) * nb11;
V += (first_y + blockIdx.y*Dv) * nb21;
maskh += (first_y + blockIdx.y*Dk);
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
// Increment pointers after each loop:
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*Dk + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}
__syncthreads();
}
float kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -196,12 +221,14 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_softcap) {
sum = softcap*tanhf(sum);
if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
sum += maskf_shared[j*Dk + i_KQ];
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -215,7 +242,6 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_new_arr[j];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
@@ -246,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32(
break;
}
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
const float V_ki = dequantize_1_v(V + k*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*Dk + k];
@@ -256,8 +282,8 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf) {
const float sink = sinksf[blockIdx.y];
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -270,7 +296,7 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
@@ -309,26 +335,28 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
float dst_val = VKQ[j_VKQ];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = Dk/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = Dk != 128;
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<Dv, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
}
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -341,59 +369,54 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float softcap;
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 8;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}


@@ -16,6 +16,8 @@
#include <cstdint>
#define FATTN_KQ_STRIDE 256
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -26,6 +28,25 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int32_t precision = KQV->op_params[3];
const int32_t n_swa = KQV->op_params[4];
ggml_tensor local_dst, Kl, Vl, Ml;
if (n_swa > 0) {
int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
int first = K->ne[1] - nton;
if (first > 0) {
local_dst = *dst;
Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first;
Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first;
local_dst.src[1] = &Kl;
local_dst.src[2] = &Vl;
local_dst.src[3] = &Ml;
local_dst.op_params[4] = 0;
dst = &local_dst;
}
}
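A worked example (made-up sizes) of the slicing above: with a sliding window only the last n_swa plus n_tokens KV positions can be unmasked, so K, V and the mask are narrowed to a suffix whose length is rounded up to a multiple of FATTN_KQ_STRIDE, and op_params[4] is cleared so the sliced op is treated as regular attention downstream.

#include <algorithm>
#include <cstdio>

int main() {
    const int FATTN_KQ_STRIDE = 256;
    const int n_swa = 4096, n_q = 32, n_kv = 16384;  // hypothetical sizes
    const int ntokens = std::max(FATTN_KQ_STRIDE, n_q);
    const int nton  = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
    const int first = n_kv - nton;
    std::printf("keep the last %d of %d KV positions, starting at %d\n", nton, n_kv, first); // 4352 of 16384, at 12032
    return 0;
}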
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
@@ -68,14 +89,14 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (precision == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
//if (precision == GGML_PREC_DEFAULT) {
// ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
//} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
}
//}
return;
}


@@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
float params[] = { scale, max_bias, softcap };
ggml_set_op_params(result, params, sizeof(params));
ggml_set_op_params_i32(result, 4, 0);
result->op = GGML_OP_FLASH_ATTN_EXT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;


@@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_model & model = lctx.model;
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
@@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv(
float kq_scale,
const llm_build_cb & cb,
int il,
ggml_tensor * sinks = nullptr) {
ggml_tensor * sinks = nullptr, int n_swa = 0) {
const llama_hparams & hparams = lctx.model.hparams;
const llama_cparams & cparams = lctx.cparams;
@@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa);
cb(cur, "kqv_out", il);
return cur;
@@ -8766,7 +8769,8 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
}
if (il == n_layer - 1) {
@@ -12198,7 +12202,8 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}
cur = llm_build_norm(ctx0, cur, hparams,
@@ -12335,7 +12340,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr,
KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0);
}
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -14400,7 +14406,8 @@ struct llm_build_context {
}
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr,
is_sliding ? hparams.n_swa : 0);
}
if (il == n_layer - 1) {
@@ -15490,7 +15497,8 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks,
is_sliding ? hparams.n_swa : 0);
cb(cur, "attn_out", il);
}