diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 16c8c24f..aa85716e 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; +typedef void (* fattn_kernel_mma_t)( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int2 * __restrict__ bounds, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3); + template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { @@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, const char * __restrict__ sinks, + const int2 * __restrict__ bounds, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16( const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_start_kernel = kb0_start * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + + if (bounds) { + kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter); + kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter); + } constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { @@ -987,8 +1032,15 @@ static __global__ void flash_attn_mma_ext_f16( const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_start_kernel = kb0_start * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + if (bounds) { + if (kb0_start_kernel*KQ_per_iter >= bounds[jt].y || kb0_stop_kernel*KQ_per_iter < bounds[jt].x) { + return; + } + kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter); + kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter); + } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; @@ -1144,9 +1196,102 @@ static __global__ void flash_attn_mma_combine_results( dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; } +template +static __device__ __forceinline__ int warp_reduce_all(int x) { + if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; + } +} + +template +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) +static __global__ void flash_attn_mask_to_KV_min_max( + const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) { + const int ne31 = gridDim.x; + const int tid = threadIdx.x; + const int sequence = blockIdx.y; + const int jt = blockIdx.x; + + mask += sequence*s33 + jt*ncols1*s31; + + __shared__ int buf_iw[WARP_SIZE]; + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE; + } + + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_min_sj = 0; + for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].x = KV_min_sj; + } +} + + template void launch_fattn_mma( - ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; @@ -1179,6 +1324,7 @@ void launch_fattn_mma( ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc KV_min_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); @@ -1225,6 +1371,19 @@ void launch_fattn_mma( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + if (mask && (Q->ne[1] >= 1024 || K->ne[1] >= 1024)) { + const int s31 = mask->nb[1] / sizeof(half2); + const int s33 = mask->nb[3] / sizeof(half2); + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + KV_min_max.alloc(ne_KV_max); + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + CUDA_CHECK(cudaGetLastError()); + } + const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; if (stream_k) { @@ -1313,6 +1472,7 @@ void launch_fattn_mma( V_data, mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *)sinks->data) : nullptr, + KV_min_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1372,7 +1532,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - fattn_kernel_t fattn_kernel; + fattn_kernel_mma_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; fattn_kernel = flash_attn_mma_ext_f16;