From 629f546db1cfc3da92c90be6cd51d437373ea389 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 29 Jan 2026 06:28:00 +0000 Subject: [PATCH] Be able to set FA offset via command line argument --- ggml/src/ggml-cuda.cu | 16 ++++++ ggml/src/ggml-cuda/common.cuh | 7 +-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 80 +++++++++------------------- ggml/src/ggml-cuda/fattn-new-mma.cu | 50 +++++++---------- 4 files changed, 65 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b533dd52..b3ec2c6b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -4526,6 +4526,7 @@ struct cuda_params { int fusion = GGML_CUDA_FUSION; int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD; int mmq_id_thresh = 32; + float fa_offset = 0; #ifdef USE_CUDA_GRAPH bool use_cuda_graph = true; #else @@ -4581,6 +4582,17 @@ static cuda_params ggml_cuda_parse_params(const char * params_string) { else if (parsed[0] == "enable-p2p") { is_good = read_value(parsed[1], params.enable_p2p); } + else if (parsed[0] == "fa-offset") { + float tmp; + is_good = read_value(parsed[1], tmp); + if (is_good) { + if (tmp < 0.0f || tmp > 3.0f) { + GGML_CUDA_LOG_WARN("%s: bad value for %s. It is %g, but must be in [0...3]\n", __func__, parsed[0].c_str(), tmp); + } else { + params.fa_offset = tmp; + } + } + } #ifdef USE_CUDA_GRAPH else if (parsed[0] == "graphs") { is_good = read_value(parsed[1], params.use_cuda_graph); @@ -4627,6 +4639,10 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh); ctx->mmq_id_thresh = params.mmq_id_thresh; } + if (params.fa_offset != ctx->fa_offset) { + GGML_CUDA_LOG_INFO(" =========================== %s: setting fa_offset to %g\n", __func__, params.fa_offset); + ctx->fa_offset = params.fa_offset; + } enable_p2p = params.enable_p2p; #ifdef USE_CUDA_GRAPH if (params.use_cuda_graph != ctx->use_cuda_graph) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 45021c3e..f4209197 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -850,9 +850,10 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - int fusion = GGML_CUDA_FUSION; - int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD; - int mmq_id_thresh = 32; + int fusion = GGML_CUDA_FUSION; + int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD; + int mmq_id_thresh = 32; + float fa_offset = 0.0f; #ifdef USE_CUDA_GRAPH bool use_cuda_graph = true; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 050186ce..6765b066 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -27,30 +27,15 @@ typedef void (* fattn_kernel_mma_t)( const float m0, const float m1, const float softcap, + const float fa_offset, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3); + const int ne00, const int ne01, const int ne02, const int ne03, + const int ne10, const int ne11, const int ne12, const int ne13, + const int ne31, const int nb31, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13, + const int nb21, const int nb22, const int nb23, + const int ne0, const int ne1, const int ne2, const int ne3); template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( @@ -160,6 +145,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float scale, const float slope, const float logit_softcap, + const float fa_offset, const int ne01, const int ne02, const int stride_KV, @@ -264,7 +250,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET); + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset); } } @@ -319,7 +305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < tile_C_KQ_16::ne; ++l) { const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET); + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset); } } } @@ -470,6 +456,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float scale, const float slope, const float logit_softcap, + const float fa_offset, const int ne01, const int ne02, const int stride_Q1, @@ -592,13 +579,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -918,30 +905,15 @@ static __global__ void flash_attn_mma_ext_f16( const float m0, const float m1, const float logit_softcap, + const float fa_offset, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int ne00, const int ne01, const int ne02, const int ne03, + const int ne10, const int ne11, const int ne12, const int ne13, + const int ne31, const int nb31, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13, + const int nb21, const int nb22, const int nb23, + const int ne0, const int ne1, const int ne2, const int ne3) { #if defined(INT8_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -1000,12 +972,12 @@ static __global__ void flash_attn_mma_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1042,7 +1014,7 @@ static __global__ void flash_attn_mma_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); @@ -1486,7 +1458,7 @@ void launch_fattn_mma( sinks ? ((const char *)sinks->data) : nullptr, KV_min_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, logit_softcap, + scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 89ee392e..0e7908c2 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -28,30 +28,15 @@ typedef void (* fattn_new_mma_kernel_t)( const float m0, const float m1, const float softcap, + const float fa_offset, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3); + const int ne00, const int ne01, const int ne02, const int ne03, + const int ne10, const int ne11, const int ne12, const int ne13, + const int ne31, const int nb31, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13, + const int nb21, const int nb22, const int nb23, + const int ne0, const int ne1, const int ne2, const int ne3); typedef tile<16, 8, half2> tile_A; @@ -542,6 +527,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float scale, const float slope, const float logit_softcap, + const float fa_offset, const int ne01, const int ne02, const int stride_K, @@ -702,7 +688,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET); + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset); } } @@ -756,7 +742,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < tile_C_KQ_16::ne; ++l) { const int KQ_index = 2*t + (l/2) % 2; - KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET); + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset); } } } @@ -928,6 +914,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float scale, const float slope, const float logit_softcap, + const float fa_offset, const int ne01, const int ne02, const int gqa_ratio, @@ -1066,13 +1053,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -1403,6 +1390,7 @@ static __global__ void flash_attn_ext_f16( const float m0, const float m1, const float logit_softcap, + const float fa_offset, const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, const int ne03, const int ne10, const int ne11, const int ne12, const int ne13, @@ -1484,12 +1472,12 @@ static __global__ void flash_attn_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } @@ -1530,7 +1518,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -1961,7 +1949,7 @@ static void launch_fattn_new_mma( sinks ? ((const char *)sinks->data) : nullptr, KV_max.get(), !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, logit_softcap, n_head_log2, + scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,