diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8e7fadba..592d3998 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -866,337 +866,3 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); } -template // D == head size -__launch_bounds__(D, 1) -static __global__ void flash_attn_mma_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { - constexpr int ncols = ncols1*ncols2; - - const int bidx0 = blockIdx.x; - const int j = blockIdx.y; - const int c = blockIdx.z; - const int jc = j*ncols2 + c; - const int tid = threadIdx.x; - - const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - - const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; - const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; - if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { - return; - } - - const int channel = kbc0 / (iter_k*iter_j); - const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; - - if (jt*ncols1 + j >= ne01) { - return; - } - - dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; - - // Load the partial result that needs a fixup: - float dst_val = 0.0f; - float max_val = 0.0f; - float rowsum = 0.0f; - { - dst_val = *dst; - - const float2 tmp = dst_fixup[bidx0*ncols + jc]; - max_val = tmp.x; - rowsum = tmp.y; - } - - - // Iterate over previous blocks and compute the combined results. - // All CUDA blocks that get here must have a previous block that needs a fixup. 
- int bidx = bidx0 - 1; - int kbc_stop = kbc0; - while(true) { - const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - if (kbc == kbc_stop) { // Did not have any data. - bidx--; - kbc_stop = kbc; - continue; - } - - const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; - - const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; - - // Scale the current and new value accumulators depending on the max. values. - const float max_val_new = fmaxf(max_val, tmp.x); - - const float diff_val = max_val - max_val_new; - const float diff_add = tmp.x - max_val_new; - - const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; - const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; - - dst_val = scale_val*dst_val + scale_add*dst_add; - rowsum = scale_val*rowsum + scale_add*tmp.y; - - max_val = max_val_new; - - // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { - break; - } - bidx--; - kbc_stop = kbc; - } - - // Write back final result: - *dst = dst_val / rowsum; -} - -template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_mma_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, - const int parallel_blocks) { - VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; - dst += D * gridDim.z*blockIdx.x; - - const int tid = threadIdx.x; - __builtin_assume(tid < D); - - extern __shared__ float2 meta[]; - if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; - } - - __syncthreads(); - - float kqmax = meta[0].x; - for (int l = 1; l < 
parallel_blocks; ++l) { - kqmax = max(kqmax, meta[l].x); - } - - float VKQ_numerator = 0.0f; - float VKQ_denominator = 0.0f; - for (int l = 0; l < parallel_blocks; ++l) { - const float diff = meta[l].x - kqmax; - float KQ_max_scale = expf(diff); - const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint32_t *) &KQ_max_scale) &= ftz_mask; - - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; - VKQ_denominator += KQ_max_scale * meta[l].y; - } - - dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; -} - -template -void launch_fattn_mma( - ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, - const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE -) { - constexpr int ncols = ncols1 * ncols2; - - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - const ggml_tensor * mask = dst->src[3]; - const ggml_tensor * sinks = dst->src[4]; - - ggml_tensor * KQV = dst; - - GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(KQV->type == GGML_TYPE_F32); - - GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); - GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); - - GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - - GGML_ASSERT(Q->ne[3] == 1); - - ggml_cuda_pool & pool = ctx.pool(); - cudaStream_t main_stream = ctx.stream(); - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int nsm = ggml_cuda_info().devices[id].nsm; - - ggml_cuda_pool_alloc K_f16(pool); - ggml_cuda_pool_alloc V_f16(pool); - ggml_cuda_pool_alloc dst_tmp(pool); - ggml_cuda_pool_alloc dst_tmp_meta(pool); - - const char * K_data = (const char *) K->data; - 
size_t nb11 = K->nb[1]; - size_t nb12 = K->nb[2]; - size_t nb13 = K->nb[3]; - - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; - - if (need_f16_K && K->type != GGML_TYPE_F16) { - K_f16.alloc(ggml_nelements(K)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); - K_data = (char *) K_f16.ptr; - - const size_t bs = ggml_blck_size(K->type); - const size_t ts = ggml_type_size(K->type); - - nb11 = nb11*bs*sizeof(half)/ts; - nb12 = nb12*bs*sizeof(half)/ts; - nb13 = nb13*bs*sizeof(half)/ts; - } - - if (need_f16_V && V->type != GGML_TYPE_F16) { - V_f16.alloc(ggml_nelements(V)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; - } - - int parallel_blocks = 1; - - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; - - const dim3 block_dim(warp_size, nwarps, 1); - dim3 blocks_num; - if (stream_k) { - // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; - const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; - const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); - - const int nblocks_stream_k = max_blocks; - - const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; - - blocks_num.x = use_stream_k ? 
nblocks_stream_k : ntiles_total; - blocks_num.y = 1; - blocks_num.z = 1; - - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); - } else { - GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); - const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - - int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); - - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: - parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); - - // parallel_blocks must not be larger than what the tensor size allows: - parallel_blocks = std::min(parallel_blocks, ntiles_KQ); - - // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. - // Test whether parallel_blocks can be set to a higher value for better efficiency. - const int blocks_per_wave = nsm * max_blocks_per_sm; - int nwaves_best = 0; - int efficiency_percent_best = 0; - for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { - const int nblocks_total = ntiles_total * parallel_blocks_test; - const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; - const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); - - // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. 
- if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { - break; - } - - if (efficiency_percent > efficiency_percent_best) { - nwaves_best = nwaves; - efficiency_percent_best = efficiency_percent; - parallel_blocks = parallel_blocks_test; - } - } - - blocks_num.x = ntiles_x; - blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; - - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } - } - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - - const uint32_t n_head = Q->ne[2]; - const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<>>( - (const char *) Q->data, - K_data, - V_data, - mask ? ((const char *) mask->data) : nullptr, - sinks ? ((const char *)sinks->data) : nullptr, - !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - nb11, nb12, nb13, - nb21, nb22, nb23, - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - CUDA_CHECK(cudaGetLastError()); - - if (stream_k) { - if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. 
- const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - - flash_attn_mma_stream_k_fixup - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); - } - } else if (parallel_blocks > 1) { - const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); - const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - - flash_attn_mma_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); - } - CUDA_CHECK(cudaGetLastError()); -} - diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 14444832..3603c040 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1010,6 +1010,340 @@ static __global__ void flash_attn_mma_ext_f16( #endif // defined(INT8_MMA_AVAILABLE) } +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_mma_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / 
(iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. 
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
+template // D == head size. Kernel below merges the parallel_blocks partial results of one output row into dst; thread tid handles element tid of the row, blockIdx.y is not used.
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_mma_combine_results(
+        const float  * __restrict__ VKQ_parts, // partial outputs, indexed [blockIdx.x][part][blockIdx.z][D]
+        const float2 * __restrict__ VKQ_meta,  // per part: .x = max used as exponent offset (presumably the KQ max), .y = term summed into the denominator; indexed [blockIdx.x][blockIdx.z][part]
+        float * __restrict__ dst,              // combined output, indexed [blockIdx.x][blockIdx.z][D]
+        const int parallel_blocks) {           // number of parts per row; dynamic shared memory must hold this many float2
+    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
+    dst       +=                 D * gridDim.z*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    extern __shared__ float2 meta[]; // staged copy of this row's parallel_blocks metadata entries
+    for (int i = tid; i < 2*parallel_blocks; i += D) { // strided so that all 2*parallel_blocks floats are loaded even when that exceeds the D threads of the block
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+    }
+
+    __syncthreads(); // meta must be completely written before any thread reads it
+
+    float kqmax = meta[0].x; // maximum over all parts, used as the common reference for rescaling
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax; // <= 0 by construction of kqmax
+        float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); // all ones if the scale is kept, zero otherwise
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask; // branch-free flush of negligible scales to 0.0f
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+}
+
+template
+void launch_fattn_mma(
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+) {
+    constexpr int ncols = ncols1 * ncols2;
+
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const
ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = 
nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. 
+ const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *)sinks->data) : nullptr, + !stream_k && parallel_blocks > 1 ? 
dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_mma_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_mma_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int ncols = ncols1 * ncols2;