mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Bounds for flash attention
This commit is contained in:
@@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16;
|
||||
typedef tile<16, 4, half2> tile_C_VKQ;
|
||||
typedef tile<16, 8, half2> tile_C_VKQ_16;
|
||||
|
||||
typedef void (* fattn_kernel_mma_t)(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int2 * __restrict__ bounds,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3);
|
||||
|
||||
template<int D, int nwarps, int KQ_per_iter>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
|
||||
@@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int2 * __restrict__ bounds,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
||||
if (bounds) {
|
||||
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
|
||||
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
|
||||
}
|
||||
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
@@ -987,8 +1032,15 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
if (bounds) {
|
||||
if (kb0_start_kernel*KQ_per_iter >= bounds[jt].y || kb0_stop_kernel*KQ_per_iter < bounds[jt].x) {
|
||||
return;
|
||||
}
|
||||
kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter);
|
||||
kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter);
|
||||
}
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
@@ -1144,9 +1196,102 @@ static __global__ void flash_attn_mma_combine_results(
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
|
||||
return __all_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template <int ncols1>
|
||||
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
||||
static __global__ void flash_attn_mask_to_KV_min_max(
|
||||
const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
|
||||
const int ne31 = gridDim.x;
|
||||
const int tid = threadIdx.x;
|
||||
const int sequence = blockIdx.y;
|
||||
const int jt = blockIdx.x;
|
||||
|
||||
mask += sequence*s33 + jt*ncols1*s31;
|
||||
|
||||
__shared__ int buf_iw[WARP_SIZE];
|
||||
if (tid < WARP_SIZE) {
|
||||
buf_iw[tid] = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
|
||||
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
|
||||
int all_inf = 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols1; ++j) {
|
||||
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
|
||||
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
|
||||
}
|
||||
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
if (tid % WARP_SIZE == 0) {
|
||||
buf_iw[tid / WARP_SIZE] = all_inf;
|
||||
}
|
||||
__syncthreads();
|
||||
all_inf = buf_iw[tid % WARP_SIZE];
|
||||
__syncthreads();
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
|
||||
if (!all_inf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE;
|
||||
}
|
||||
|
||||
if (tid < WARP_SIZE) {
|
||||
buf_iw[tid] = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int KV_min_sj = 0;
|
||||
for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) {
|
||||
int all_inf = 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols1; ++j) {
|
||||
const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]);
|
||||
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
|
||||
}
|
||||
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
if (tid % WARP_SIZE == 0) {
|
||||
buf_iw[tid / WARP_SIZE] = all_inf;
|
||||
}
|
||||
__syncthreads();
|
||||
all_inf = buf_iw[tid % WARP_SIZE];
|
||||
__syncthreads();
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
|
||||
if (!all_inf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KV_min_max[sequence*ne31 + jt].x = KV_min_sj;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
void launch_fattn_mma(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
@@ -1179,6 +1324,7 @@ void launch_fattn_mma(
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<int2> KV_min_max(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
@@ -1225,6 +1371,19 @@ void launch_fattn_mma(
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
if (mask && (Q->ne[1] >= 1024 || K->ne[1] >= 1024)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
||||
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
|
||||
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
|
||||
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
||||
KV_min_max.alloc(ne_KV_max);
|
||||
flash_attn_mask_to_KV_min_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
@@ -1313,6 +1472,7 @@ void launch_fattn_mma(
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
KV_min_max.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
@@ -1372,7 +1532,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
fattn_kernel_t fattn_kernel;
|
||||
fattn_kernel_mma_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_mma_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
||||
|
||||
Reference in New Issue
Block a user