Allow setting the FA offset via a command-line argument

This commit is contained in:
Kawrakow
2026-01-29 06:28:00 +00:00
parent 02ae22388f
commit 629f546db1
4 changed files with 65 additions and 88 deletions

View File

@@ -4526,6 +4526,7 @@ struct cuda_params {
int fusion = GGML_CUDA_FUSION;
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
int mmq_id_thresh = 32;
float fa_offset = 0;
#ifdef USE_CUDA_GRAPH
bool use_cuda_graph = true;
#else
@@ -4581,6 +4582,17 @@ static cuda_params ggml_cuda_parse_params(const char * params_string) {
else if (parsed[0] == "enable-p2p") {
is_good = read_value(parsed[1], params.enable_p2p);
}
else if (parsed[0] == "fa-offset") {
float tmp;
is_good = read_value(parsed[1], tmp);
if (is_good) {
if (tmp < 0.0f || tmp > 3.0f) {
GGML_CUDA_LOG_WARN("%s: bad value for %s. It is %g, but must be in [0...3]\n", __func__, parsed[0].c_str(), tmp);
} else {
params.fa_offset = tmp;
}
}
}
#ifdef USE_CUDA_GRAPH
else if (parsed[0] == "graphs") {
is_good = read_value(parsed[1], params.use_cuda_graph);
@@ -4627,6 +4639,10 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
ctx->mmq_id_thresh = params.mmq_id_thresh;
}
if (params.fa_offset != ctx->fa_offset) {
GGML_CUDA_LOG_INFO(" =========================== %s: setting fa_offset to %g\n", __func__, params.fa_offset);
ctx->fa_offset = params.fa_offset;
}
enable_p2p = params.enable_p2p;
#ifdef USE_CUDA_GRAPH
if (params.use_cuda_graph != ctx->use_cuda_graph) {

View File

@@ -850,9 +850,10 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
int fusion = GGML_CUDA_FUSION;
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
int mmq_id_thresh = 32;
int fusion = GGML_CUDA_FUSION;
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
int mmq_id_thresh = 32;
float fa_offset = 0.0f;
#ifdef USE_CUDA_GRAPH
bool use_cuda_graph = true;

View File

@@ -27,30 +27,15 @@ typedef void (* fattn_kernel_mma_t)(
const float m0,
const float m1,
const float softcap,
const float fa_offset,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3);
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne10, const int ne11, const int ne12, const int ne13,
const int ne31, const int nb31,
const int nb01, const int nb02, const int nb03,
const int nb11, const int nb12, const int nb13,
const int nb21, const int nb22, const int nb23,
const int ne0, const int ne1, const int ne2, const int ne3);
template<int D, int nwarps, int KQ_per_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
@@ -160,6 +145,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float scale,
const float slope,
const float logit_softcap,
const float fa_offset,
const int ne01,
const int ne02,
const int stride_KV,
@@ -264,7 +250,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset);
}
}
@@ -319,7 +305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
const int KQ_index = 2*t + (l/2) % 2;
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset);
}
}
}
@@ -470,6 +456,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float scale,
const float slope,
const float logit_softcap,
const float fa_offset,
const int ne01,
const int ne02,
const int stride_Q1,
@@ -592,13 +579,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
@@ -918,30 +905,15 @@ static __global__ void flash_attn_mma_ext_f16(
const float m0,
const float m1,
const float logit_softcap,
const float fa_offset,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne10, const int ne11, const int ne12, const int ne13,
const int ne31, const int nb31,
const int nb01, const int nb02, const int nb03,
const int nb11, const int nb12, const int nb13,
const int nb21, const int nb22, const int nb23,
const int ne0, const int ne1, const int ne2, const int ne3) {
#if defined(INT8_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
@@ -1000,12 +972,12 @@ static __global__ void flash_attn_mma_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1042,7 +1014,7 @@ static __global__ void flash_attn_mma_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
@@ -1486,7 +1458,7 @@ void launch_fattn_mma(
sinks ? ((const char *)sinks->data) : nullptr,
KV_min_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,

View File

@@ -28,30 +28,15 @@ typedef void (* fattn_new_mma_kernel_t)(
const float m0,
const float m1,
const float softcap,
const float fa_offset,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3);
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne10, const int ne11, const int ne12, const int ne13,
const int ne31, const int nb31,
const int nb01, const int nb02, const int nb03,
const int nb11, const int nb12, const int nb13,
const int nb21, const int nb22, const int nb23,
const int ne0, const int ne1, const int ne2, const int ne3);
typedef tile<16, 8, half2> tile_A;
@@ -542,6 +527,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float scale,
const float slope,
const float logit_softcap,
const float fa_offset,
const int ne01,
const int ne02,
const int stride_K,
@@ -702,7 +688,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset);
}
}
@@ -756,7 +742,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#pragma unroll
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
const int KQ_index = 2*t + (l/2) % 2;
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset);
}
}
}
@@ -928,6 +914,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float scale,
const float slope,
const float logit_softcap,
const float fa_offset,
const int ne01,
const int ne02,
const int gqa_ratio,
@@ -1066,13 +1053,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
@@ -1403,6 +1390,7 @@ static __global__ void flash_attn_ext_f16(
const float m0,
const float m1,
const float logit_softcap,
const float fa_offset,
const uint32_t n_head_log2,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne10, const int ne11, const int ne12, const int ne13,
@@ -1484,12 +1472,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1530,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -1961,7 +1949,7 @@ static void launch_fattn_new_mma(
sinks ? ((const char *)sinks->data) : nullptr,
KV_max.get(),
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,