mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 09:04:10 +00:00
Be able to set FA offset via command line argument
This commit is contained in:
@@ -4526,6 +4526,7 @@ struct cuda_params {
|
||||
int fusion = GGML_CUDA_FUSION;
|
||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
int mmq_id_thresh = 32;
|
||||
float fa_offset = 0;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
bool use_cuda_graph = true;
|
||||
#else
|
||||
@@ -4581,6 +4582,17 @@ static cuda_params ggml_cuda_parse_params(const char * params_string) {
|
||||
else if (parsed[0] == "enable-p2p") {
|
||||
is_good = read_value(parsed[1], params.enable_p2p);
|
||||
}
|
||||
else if (parsed[0] == "fa-offset") {
|
||||
float tmp;
|
||||
is_good = read_value(parsed[1], tmp);
|
||||
if (is_good) {
|
||||
if (tmp < 0.0f || tmp > 3.0f) {
|
||||
GGML_CUDA_LOG_WARN("%s: bad value for %s. It is %g, but must be in [0...3]\n", __func__, parsed[0].c_str(), tmp);
|
||||
} else {
|
||||
params.fa_offset = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
else if (parsed[0] == "graphs") {
|
||||
is_good = read_value(parsed[1], params.use_cuda_graph);
|
||||
@@ -4627,6 +4639,10 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
|
||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
|
||||
ctx->mmq_id_thresh = params.mmq_id_thresh;
|
||||
}
|
||||
if (params.fa_offset != ctx->fa_offset) {
|
||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting fa_offset to %g\n", __func__, params.fa_offset);
|
||||
ctx->fa_offset = params.fa_offset;
|
||||
}
|
||||
enable_p2p = params.enable_p2p;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
if (params.use_cuda_graph != ctx->use_cuda_graph) {
|
||||
|
||||
@@ -850,9 +850,10 @@ struct ggml_backend_cuda_context {
|
||||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
int fusion = GGML_CUDA_FUSION;
|
||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
int mmq_id_thresh = 32;
|
||||
int fusion = GGML_CUDA_FUSION;
|
||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||
int mmq_id_thresh = 32;
|
||||
float fa_offset = 0.0f;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
|
||||
@@ -27,30 +27,15 @@ typedef void (* fattn_kernel_mma_t)(
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const float fa_offset,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3);
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const int ne31, const int nb31,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3);
|
||||
|
||||
template<int D, int nwarps, int KQ_per_iter>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
@@ -160,6 +145,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float scale,
|
||||
const float slope,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int stride_KV,
|
||||
@@ -264,7 +250,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,7 +305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
||||
const int KQ_index = 2*t + (l/2) % 2;
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -470,6 +456,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float scale,
|
||||
const float slope,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int stride_Q1,
|
||||
@@ -592,13 +579,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
|
||||
@@ -918,30 +905,15 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const int ne31, const int nb31,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
@@ -1000,12 +972,12 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -1042,7 +1014,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
@@ -1486,7 +1458,7 @@ void launch_fattn_mma(
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
KV_min_max.ptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
|
||||
@@ -28,30 +28,15 @@ typedef void (* fattn_new_mma_kernel_t)(
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const float fa_offset,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3);
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const int ne31, const int nb31,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3);
|
||||
|
||||
|
||||
typedef tile<16, 8, half2> tile_A;
|
||||
@@ -542,6 +527,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float scale,
|
||||
const float slope,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int stride_K,
|
||||
@@ -702,7 +688,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + fa_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -756,7 +742,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
||||
const int KQ_index = 2*t + (l/2) % 2;
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + fa_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -928,6 +914,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float scale,
|
||||
const float slope,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int gqa_ratio,
|
||||
@@ -1066,13 +1053,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
|
||||
@@ -1403,6 +1390,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float logit_softcap,
|
||||
const float fa_offset,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
@@ -1484,12 +1472,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -1530,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, fa_offset,
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
@@ -1961,7 +1949,7 @@ static void launch_fattn_new_mma(
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
KV_max.get(),
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||
scale, max_bias, m0, m1, logit_softcap, ctx.fa_offset, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
|
||||
Reference in New Issue
Block a user