Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
Alternative CUDA FA for SWA models (#754)
* Bounds for flash attention
* Add n_swa to FA parameters
* Fix it
* This seems very slightly better
* Using vec kernel when we have SWA
* Need also this
* f32 vec kernel
* This is slightly better

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
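In short: with sliding-window attention (SWA) the mask is -inf for every KV position outside the window, so entire FATTN_KQ_STRIDE-sized blocks of the KV cache contribute nothing and can be skipped. This commit (a) stores n_swa in the flash-attention op parameters, (b) adds a kernel that scans the mask once per tile of Q rows to recover per-tile [KV_min, KV_max) bounds, (c) threads those bounds through the MMA and vector kernels, and (d) prefers the vector kernels for single-token SWA decoding. The sketch below is not part of the commit; it only illustrates, for a causal sliding window, which bounds the mask scan effectively recovers (the helper and its names are illustrative):

    #include <algorithm>

    struct Bounds { int lo, hi; };

    // Row j of a causal SWA mask is -inf outside [pos_j - n_swa, pos_j].
    // For a tile of Q rows covering positions [pos_first, pos_last], only
    // KV blocks intersecting [pos_first - n_swa, pos_last] can contribute.
    Bounds swa_bounds(int pos_first, int pos_last, int n_swa, int kq_stride) {
        int lo = std::max(0, pos_first - n_swa);
        lo -= lo % kq_stride;                                      // round down to a block boundary
        int hi = ((pos_last + kq_stride) / kq_stride) * kq_stride; // round pos_last + 1 up
        return {lo, hi};
    }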
@@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16;
 typedef tile<16, 4, half2> tile_C_VKQ;
 typedef tile<16, 8, half2> tile_C_VKQ_16;
 
+typedef void (* fattn_kernel_mma_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int2 * __restrict__ bounds,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const float softcap,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
 template<int D, int nwarps, int KQ_per_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
@@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16(
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int2 * __restrict__ bounds,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16(
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        int kb0_start_kernel = kb0_start * kb_niter;
+        int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        if (bounds) {
+            kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
+            kb0_stop_kernel  = min(kb0_stop_kernel,  bounds[jt].y / KQ_per_iter);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -987,8 +1032,12 @@ static __global__ void flash_attn_mma_ext_f16(
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-    const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    int kb0_start_kernel = kb0_start * kb_niter;
+    int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    if (bounds) {
+        kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
+        kb0_stop_kernel  = min(kb0_stop_kernel,  bounds[jt].y / KQ_per_iter);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
    constexpr bool needs_fixup = false;
@@ -1144,9 +1193,109 @@ static __global__ void flash_attn_mma_combine_results(
     dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+    if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template <int ncols1, bool is_swa>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_min_max(
+        const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
+    const int ne31 = gridDim.x;
+    const int tid = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    if constexpr (!is_swa) {
+        if (threadIdx.x == 0) {
+            KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE};
+        }
+        return;
+    }
+
+    if (threadIdx.x == 0) {
+        KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
+    }
+
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_min_sj = 0;
+    for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        KV_min_max[sequence*ne31 + jt].x = KV_min_sj;
+    }
+}
+
+
 template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn_mma(
-    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
     const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
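The scan in flash_attn_mask_to_KV_min_max walks the mask in FATTN_KQ_STRIDE-wide column blocks: top-down to find KV_max (the end of the last block that is not entirely -inf) and, only in the is_swa instantiation, bottom-up for KV_min; warp_reduce_all plus the buf_iw shared buffer turn the per-thread isinf checks into a single block-wide vote. A host-side oracle of the same logic (hypothetical, not in the commit; handy as a mental model or unit-test reference, assuming the mask length is a multiple of the stride):

    #include <cmath>
    #include <vector>

    struct Bounds { int lo, hi; };

    // mask: one value per KV position for a tile of Q rows, already reduced
    // over the tile (an entry is -inf only if it is -inf for every row).
    Bounds kv_min_max_ref(const std::vector<float> & mask, int kq_stride, bool is_swa) {
        auto block_all_inf = [&](int start) {
            for (int i = start; i < start + kq_stride; ++i) {
                if (!std::isinf(mask[i])) return false;
            }
            return true;
        };
        int hi = (int) mask.size();
        while (hi > 0 && block_all_inf(hi - kq_stride)) hi -= kq_stride;  // KV_max scan
        if (!is_swa) return {0, hi};     // without SWA the tile always starts at 0
        int lo = 0;
        while (lo < hi && block_all_inf(lo)) lo += kq_stride;             // KV_min scan
        return {lo, hi};
    }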
@@ -1171,6 +1320,9 @@ void launch_fattn_mma(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    int n_swa;
+    memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int));
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -1179,6 +1331,7 @@ void launch_fattn_mma(
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int2>   KV_min_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -1225,11 +1378,29 @@ void launch_fattn_mma(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+        KV_min_max.alloc(ne_KV_max);
+        if (n_swa > 0) {
+            flash_attn_mask_to_KV_min_max<ncols1, true><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+                ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
+        } else {
+            flash_attn_mask_to_KV_min_max<ncols1, false><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+                ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
+        }
+        CUDA_CHECK(cudaGetLastError());
+    }
+
     const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = Q->ne[1] > 1 ? 2*nsm : nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
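Note when the scan is launched: always for long prompts (Q->ne[1] >= 1024, where trimming the fully-masked causal tail already pays for the extra kernel), and for SWA only once the KV cache exceeds the window by at least one block (K->ne[1] >= FATTN_KQ_STRIDE + n_swa). With FATTN_KQ_STRIDE = 256 and n_swa = 4096, for example, the bounds kernel runs from 4352 cached tokens onward; below that no whole block could be skipped anyway.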
@@ -1313,6 +1484,7 @@ void launch_fattn_mma(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         sinks ? ((const char *)sinks->data) : nullptr,
+        KV_min_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1372,7 +1544,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    fattn_kernel_t fattn_kernel;
+    fattn_kernel_mma_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
ggml/src/ggml-cuda/fattn-vec-common.cuh: new file, 1074 lines (diff suppressed because it is too large).
@@ -6,9 +6,9 @@
 //
 
 #include "common.cuh"
-#include "fattn-common.cuh"
+#include "fattn-vec-common.cuh"
 
-template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
+template<int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K, V head sizes
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(Dk, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -18,44 +18,37 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int2 * __restrict__ KV_min_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
         const float max_bias,
         const float m0,
         const float m1,
-        const float softcap,
         const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#ifdef FP16_AVAILABLE
+        const float logit_softcap,
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#if defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
-        if (use_softcap && !(Dk == 128 || Dk == 256)) {
+        if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
            NO_DEVICE_CODE;
            return;
        }
    }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -63,18 +56,19 @@ static __global__ void flash_attn_vec_ext_f16(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const float * sinksf = (const float *) (sinks);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +81,12 @@ static __global__ void flash_attn_vec_ext_f16(
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax[ncols];
-    half kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -HALF_MAX_HALF;
-        kqsum[j] = 0.0f;
     }
+
+    half kqsum[ncols] = {0.0f};
 
     __shared__ half kqmax_shared[ncols][WARP_SIZE];
     __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -102,8 +97,21 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*Dk];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*Dk + tid] = 0.0f;
+    }
+
     __syncthreads();
 
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
+
+    const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
+    const int first_y   = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
+    if (first_y + blockIdx.y*Dk < k_VKQ_max) {
+
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
     half2 Q_h2[ncols][Dk/(2*WARP_SIZE)];
     int   Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)];
@@ -179,12 +187,19 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j = 0; j < ncols; ++j) {
         KQ[j*Dk + tid] = -HALF_MAX_HALF;
     }
     __syncthreads();
 
-    half2 VKQ[ncols] = {{0.0f, 0.0f}};
-
-    const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
+    const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
+    const int first_y   = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
+    //if (KV_min_max && threadIdx.x == 0 && head == 0) {
+    //    printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk);
+    //}
+    K     += (first_y + blockIdx.y*Dk) * nb11;
+    V     += (first_y + blockIdx.y*Dv) * nb21;
+    maskh += (first_y + blockIdx.y*Dk);
+    for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
+            // Increment pointers after each loop:
+            K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
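The restructured loop replaces the old parallel_blocks/ip split: blockIdx.y now strides directly over Dk-sized KV chunks starting at first_y, and K, V and maskh are advanced in the loop header rather than re-derived from k_VKQ_0 each iteration (which is why vec_dot_KQ and dequantize_1_v below index with i_KQ and k0 alone). A small host-side simulation of the work split (illustrative constants only):

    #include <cstdio>

    int main() {
        // Each y-block covers every gridDim_y-th chunk of Dk KV positions
        // inside [first_y, k_VKQ_max), mirroring the new loop header.
        const int Dk = 128, gridDim_y = 4, first_y = 512, k_VKQ_max = 1536;
        for (int by = 0; by < gridDim_y; ++by) {
            for (int k = first_y + by*Dk; k < k_VKQ_max; k += gridDim_y*Dk) {
                printf("y-block %d: KV positions [%d, %d)\n", by, k, k + Dk);
            }
        }
        return 0;
    }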
@@ -196,6 +211,14 @@ static __global__ void flash_attn_vec_ext_f16(
             kqmax_new_arr[j] = kqmax[j];
         }
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*Dk + tid] = slopeh*maskh[j*ne11 + tid];
+            }
+            __syncthreads();
+        }
+
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
@@ -206,12 +229,14 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-            sum = warp_reduce_sum(sum);
-            if (use_softcap) {
-                sum = softcap*tanhf(sum);
+            half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+            sum = warp_reduce_sum((float)sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
             }
-            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+            sum += maskh_shared[j*Dk + i_KQ];
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -229,7 +254,6 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
-
             kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
@@ -261,8 +285,8 @@ static __global__ void flash_attn_vec_ext_f16(
             }
 
             half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];
@@ -271,9 +295,10 @@ static __global__ void flash_attn_vec_ext_f16(
 
         __syncthreads();
     }
+    }
 
-    if (sinksf) {
-        const half sink = __float2half(sinksf[blockIdx.y]);
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -286,7 +311,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        half kqmax_new_j = kqmax_shared[j][threadIdx.y];
+        half kqmax_new_j = kqmax_shared[j][threadIdx.x];
         kqmax_new_j = warp_reduce_max(kqmax_new_j);
 
         const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
@@ -307,7 +332,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
         if (threadIdx.x == 0) {
             kqsum_shared[j][threadIdx.y] = kqsum[j];
         }
@@ -322,32 +347,31 @@ static __global__ void flash_attn_vec_ext_f16(
         }
 
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+        kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
-    }
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+    }
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
 }
 
-template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
+template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = Dk/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = Dk != 128 && Dk != 192;
     constexpr bool need_f16_V = Dv != 128 && Dv != 64;
-    launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<Dv, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
 }
 
 template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -363,59 +387,54 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
-    float softcap;
-    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 2;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
-    if (softcap == 0.0f) {
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    constexpr int cols_per_block = 8;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
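Two dispatch changes stand out here: parallel_blocks is no longer a template parameter (splitting work along the KV dimension is now a launch-time decision expressed through gridDim.y), and on NVIDIA GPUs the dispatcher now always takes the cols_per_block = 1 path regardless of batch size, leaving the multi-column variants for other vendors; this matches the ncols > 1 NO_DEVICE_CODE guard compiled into the kernels outside HIP/MUSA builds.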
@@ -6,55 +6,54 @@
 //
 
 #include "common.cuh"
-#include "fattn-common.cuh"
+#include "fattn-vec-common.cuh"
 
-template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// Currenlty llvm with the amdgcn target dose not support unrolling loops
+// that contain a break that can not be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template<int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K-, V-head size
+#ifndef GGML_USE_HIP
 __launch_bounds__(Dk, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int2 * __restrict__ KV_min_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
         const float max_bias,
         const float m0,
         const float m1,
-        const float softcap,
         const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const float logit_softcap,
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 
     // Skip unused kernel variants for faster compilation:
     if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
-        if (use_softcap && !(Dk == 128 || Dk == 256)) {
+        if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
            NO_DEVICE_CODE;
            return;
        }
    }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -62,17 +61,19 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
+
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const float * sinksf = (const float *) (sinks);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
     static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 
     float kqmax[ncols];
-    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
-        kqsum[j] = 0.0f;
     }
+
+    float kqsum[ncols] = {0.0f};
 
     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -102,6 +104,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*Dk];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*Dk + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -176,10 +185,26 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
+    const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
+    const int first_y   = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
+
+    K     += (first_y + blockIdx.y*Dk) * nb11;
+    V     += (first_y + blockIdx.y*Dv) * nb21;
+    maskh += (first_y + blockIdx.y*Dk);
+    for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
+            // Increment pointers after each loop:
+            K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*Dk + tid] = slope*__half2float(maskh[j*ne11 + tid]);
+            }
+            __syncthreads();
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -196,12 +221,14 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+            float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
-            if (use_softcap) {
-                sum = softcap*tanhf(sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
             }
-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+            sum += maskf_shared[j*Dk + i_KQ];
 
             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
 
@@ -215,7 +242,6 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float kqmax_new_j = kqmax_new_arr[j];
-
             kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
@@ -246,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*Dk + k];
@@ -256,8 +282,8 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }
 
-    if (sinksf) {
-        const float sink = sinksf[blockIdx.y];
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -270,7 +296,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        float kqmax_new_j = kqmax_shared[j][threadIdx.y];
+        float kqmax_new_j = kqmax_shared[j][threadIdx.x];
         kqmax_new_j = warp_reduce_max(kqmax_new_j);
 
         const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
@@ -309,26 +335,28 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
-    }
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+    }
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
-template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
+template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = Dk/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
-    constexpr bool need_f16_K = Dk != 128 && Dk != 192;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = Dk != 128;
     constexpr bool need_f16_V = Dv != 128 && Dv != 64;
-    launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<Dv, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
 }
 
 template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -341,59 +369,54 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
-    float softcap;
-    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 2;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
-    if (softcap == 0.0f) {
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    constexpr int cols_per_block = 8;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
    } else {
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
    }
 }
@@ -16,6 +16,8 @@
 
 #include <cstdint>
 
+#define FATTN_KQ_STRIDE 256
+
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
@@ -26,6 +28,25 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[3];
+    const int32_t n_swa = KQV->op_params[4];
+
+    ggml_tensor local_dst, Kl, Vl, Ml;
+    if (n_swa > 0) {
+        int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
+        int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
+        int first = K->ne[1] - nton;
+        if (first > 0) {
+            local_dst = *dst;
+            Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
+            Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first;
+            Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first;
+            local_dst.src[1] = &Kl;
+            local_dst.src[2] = &Vl;
+            local_dst.src[3] = &Ml;
+            local_dst.op_params[4] = 0;
+            dst = &local_dst;
+        }
+    }
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
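To make the slicing concrete (numbers illustrative): with n_swa = 4096 and FATTN_KQ_STRIDE = 256, a single-token decode (Q->ne[1] = 1) against 32768 cached tokens gives ntokens = max(256, 1) = 256, nton = 256*ceil((256 + 4096)/256) = 4352 and first = 32768 - 4352 = 28416. The K, V and mask views are re-based at position 28416, so every kernel downstream only ever sees the last 4352 KV positions, and op_params[4] is reset to 0 so the sliced views are treated as a plain (non-SWA) problem.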
@@ -68,14 +89,14 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
     //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-        if (precision == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
+        //if (precision == GGML_PREC_DEFAULT) {
+        //    ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        //} else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+        //}
         return;
     }
 
@@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
     float params[] = { scale, max_bias, softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
+    ggml_set_op_params_i32(result, 4, 0);
+
     result->op   = GGML_OP_FLASH_ATTN_EXT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
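This initializes op_params slot 4 (the slot the CUDA backend reads as n_swa via KQV->op_params[4]) to 0, so existing graphs keep the old dense behavior by default. Per the commit message ("Add n_swa to FA parameters"), a sliding-window model is then expected to overwrite the slot after building the op, along the lines of this hypothetical call site:

    ggml_set_op_params_i32(result, 4, n_swa);  // n_swa from the model's hparams (illustrative)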