Better FA for GLM-4.7-Flash

Kawrakow
2026-01-26 07:31:13 +00:00
parent 30381fc1fc
commit bd7e75192e
2 changed files with 56 additions and 59 deletions


@@ -566,10 +566,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int nstages = 0;
#endif // CP_ASYNC_AVAILABLE
- constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
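A quick arithmetic check of what the new `np` expression changes, using assumed values for the tile constants (tile_B::I == 8 and nwarps == 4 are illustrative, not read from this diff). For the wide-head configuration this commit adds (ncols1 == 1, ncols2 == 32) the old formula collapses to zero warps per column group:

```cpp
// Illustrative only; tile_B_I and nwarps stand in for the kernel's tile_B::I and nwarps constants.
constexpr int nwarps = 4, ntiles = 1, tile_B_I = 8;
constexpr int ncols1 = 1, ncols2 = 32;
constexpr int cols_per_warp = ntiles * tile_B_I;   // 8
constexpr int ncols         = ncols1 * ncols2;     // 32
constexpr int np_old = nwarps * (cols_per_warp/ncols2) / ncols1;                      // 4 * (8/32) / 1 == 0
constexpr int np_new = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // 4 * 8 / 32   == 1
static_assert(np_old == 0 && np_new == 1, "old integer division breaks once ncols2 exceeds cols_per_warp");
```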
@@ -936,6 +936,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int stride_V,
const int stride_mask,
const int jt,
+ const int zt,
const int kb0_start,
const int kb0_stop) {
#ifdef INT8_MMA_AVAILABLE
@@ -952,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
@@ -1008,7 +1009,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j = jc / ncols2;
const int c = jc % ncols2;
- if (jt*ncols1 + j < ne01) {
+ if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1336,7 +1337,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;
- if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
continue;
}
@@ -1401,29 +1402,13 @@ static __global__ void flash_attn_ext_f16(
const float m1,
const float logit_softcap,
const uint32_t n_head_log2,
- const int ne00,
- const int ne01,
- const int ne02,
- const int ne03,
- const int ne10,
- const int ne11,
- const int ne12,
- const int ne13,
- const int ne31,
- const int nb31,
- const int nb01,
- const int nb02,
- const int nb03,
- const int nb11,
- const int nb12,
- const int nb13,
- const int nb21,
- const int nb22,
- const int nb23,
- const int ne0,
- const int ne1,
- const int ne2,
- const int ne3) {
+ const int ne00, const int ne01, const int ne02, const int ne03,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ const int ne31, const int nb31,
+ const int nb01, const int nb02, const int nb03,
+ const int nb11, const int nb12, const int nb13,
+ const int nb21, const int nb22, const int nb23,
+ const int ne0, const int ne1, const int ne2, const int ne3) {
#if defined(INT8_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
@@ -1455,12 +1440,13 @@ static __global__ void flash_attn_ext_f16(
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+ const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
// kbc == k block continuous, current index in continuous ijk space.
- int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
- const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z) / gridDim.x;
+ const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
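For orientation, here is a minimal host-side sketch of the stream-k style partitioning used above: each CUDA block takes a contiguous slice of the flattened (sequence x head-tile x Q-tile x K-block) index space, and the product is widened to 64 bits so the scaled index cannot overflow for long contexts with many heads. A sketch, not the kernel's code:

```cpp
#include <cstdint>

// Returns the first flattened work index (kbc) owned by `block` out of `n_blocks`.
static int slice_begin(int block, int n_blocks, int iter_k, int iter_j, int iter_z, int ne03) {
    const int64_t total = int64_t(iter_k) * iter_j * iter_z * ne03;
    return int(int64_t(block) * total / n_blocks);
}
// Block b then processes kbc in [slice_begin(b, ...), slice_begin(b + 1, ...)).
// A tile whose K blocks are split across two slices is patched up afterwards by
// flash_attn_stream_k_fixup.
```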
@@ -1470,18 +1456,19 @@ static __global__ void flash_attn_ext_f16(
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
- const int channel = kbc / (iter_k*iter_j);
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+ const int sequence = kbc / (iter_k*iter_j*iter_z);
+ const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
+ const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k;
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+ const float2 * Q_f2 = (const float2 *) (Q + nb02*zt*ncols2);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
- const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
+ float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
+ const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
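The decomposition of the flattened index into tile coordinates, written out as a standalone helper (names are illustrative; the arithmetic mirrors the lines above):

```cpp
struct TileCoord { int sequence, zt, jt, kb0; };

static TileCoord decompose_kbc(int kbc, int iter_k, int iter_j, int iter_z) {
    TileCoord t;
    t.sequence = kbc / (iter_k*iter_j*iter_z);   // batch/sequence index
    kbc       -= t.sequence * iter_k*iter_j*iter_z;
    t.zt       = kbc / (iter_k*iter_j);          // head tile, in units of ncols2
    kbc       -= t.zt * iter_k*iter_j;
    t.jt       = kbc / iter_k;                   // Q row tile, in units of ncols1
    t.kb0      = kbc % iter_k;                   // first K/V block of this slice
    return t;
}
```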
@@ -1494,12 +1481,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
}
kbc += iter_k;
@@ -1513,18 +1500,20 @@ static __global__ void flash_attn_ext_f16(
return;
}
- const int channel = kbc / (iter_k*iter_j);
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+ const int sequence = kbc / (iter_k*iter_j*iter_z);
+ const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
+ const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* zt*ncols2);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
- const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
+ float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
+ const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1536,7 +1525,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+ ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
@@ -1558,6 +1547,7 @@ __launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
constexpr int ncols = ncols1*ncols2;
+ constexpr int ne03 = 1;
const int bidx0 = blockIdx.x;
const int j = blockIdx.y;
@@ -1569,9 +1559,10 @@ static __global__ void flash_attn_stream_k_fixup(
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+ const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
- const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
- const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
+ const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -1580,14 +1571,15 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}
- const int channel = kbc0 / (iter_k*iter_j);
- const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
+ const int sequence = kbc0 / (iter_k*iter_j*iter_z);
+ const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
+ const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
- if (jt*ncols1 + j >= ne01) {
+ if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
return;
}
- dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
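The new `dst` offset can be sanity-checked against the output layout, which is [sequence][Q row][head][channel] with D floats per head. A verification sketch under that layout assumption:

```cpp
#include <cstddef>

static size_t dst_index(int sequence, int row, int head, int channel,
                        int ne01, int ne02, int D) {
    return ((size_t(sequence)*ne01 + row)*ne02 + head)*size_t(D) + channel;
}
// With row = jt*ncols1 + j, head = zt*ncols2 + c and channel = tid this expands to
// sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid,
// i.e. exactly the pointer arithmetic above.
```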
@@ -1606,7 +1598,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
- const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+ const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
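What the fixup ultimately does with the partial result is the usual log-sum-exp merge of flash-attention partials. A generic sketch follows; the exact packing of dst_fixup/dst_meta is not visible in this excerpt, so the struct layout here is assumed:

```cpp
#include <cmath>

struct Partial { float out; float m; float s; };   // normalized output, running max, softmax denominator

// Merge two partial results for the same output element: re-weight each by its
// share of the combined softmax denominator after aligning the running maxima.
static float merge(Partial a, Partial b) {
    const float m  = std::fmax(a.m, b.m);
    const float wa = std::exp(a.m - m) * a.s;
    const float wb = std::exp(b.m - m) * b.s;
    return (a.out*wa + b.out*wb) / (wa + wb);
}
```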
@@ -1846,7 +1838,8 @@ static void launch_fattn_new_mma(
int parallel_blocks = 1;
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
- const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
+ const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
+ const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
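Worked numbers for the launch-side change, assuming the GLM-4.7-Flash MLA decode shape implied by the gqa_ratio == 20 special case added later in this file (20 query heads sharing one KV head) and the new ncols2 == 32 instantiation:

```cpp
constexpr int ne02   = 20;   // assumed number of Q heads (gqa_ratio == 20, single KV head)
constexpr int ncols2 = 32;
constexpr int ntiles_z_old = ne02 / ncols2;                  // 0: plain division drops the head dimension
constexpr int ntiles_z_new = (ne02 + ncols2 - 1) / ncols2;   // 1: a single, partially filled head tile
static_assert(ntiles_z_old == 0 && ntiles_z_new == 1,
              "ceil-div keeps the padded head tile; its out-of-range columns are skipped "
              "by the new zt*ncols2 + c < ne02 checks in the kernel");
```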
@@ -1921,7 +1914,7 @@ static void launch_fattn_new_mma(
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
- blocks_num.z = Q->ne[2]*Q->ne[3];
+ blocks_num.z = ntiles_z*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -2161,6 +2154,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
+ if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 1, 32>(ctx, dst);
+ return;
+ }
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else if (gqa_ratio % 4 == 0) {


@@ -107,7 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
// so no other implementation works.
//
- if (new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 &&
+ if (false && new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 &&
Q->ne[2]/K->ne[2] == 20 && K->ne[1] > 8192) {
// GLM-4.7-Flash TG hack: split 20 heads into 16+4 heads
auto local_Q = *Q;
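For context, the hack disabled above split the work into a 16-head launch and a 4-head launch via shallow copies of the tensor metadata; the excerpt ends right after the first copy, so the following is only a rough, assumed sketch of how such a split proceeds, not the commit's actual code:

```cpp
// Hypothetical continuation (illustrative): give each view a head count that an
// existing ncols2 instantiation divides evenly, then dispatch the kernel twice.
ggml_tensor local_Q = *Q;                        // view of the first 16 heads
local_Q.ne[2] = 16;
ggml_tensor rest_Q  = *Q;                        // view of the remaining 4 heads
rest_Q.ne[2]  = 4;
rest_Q.data   = (char *) Q->data + 16*Q->nb[2];  // skip the first 16 head planes
// ...the corresponding dst views would be split the same way before launching.
```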