mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Better FA for GLM-4.7-Flash
This commit is contained in:
@@ -566,10 +566,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int nstages = 0;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
@@ -936,6 +936,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int stride_V,
|
||||
const int stride_mask,
|
||||
const int jt,
|
||||
const int zt,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
@@ -952,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
@@ -1008,7 +1009,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j = jc / ncols2;
|
||||
const int c = jc % ncols2;
|
||||
|
||||
if (jt*ncols1 + j < ne01) {
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
@@ -1336,7 +1337,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j_dst = jc_dst / ncols2;
|
||||
const int c_dst = jc_dst % ncols2;
|
||||
|
||||
if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
|
||||
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1401,29 +1402,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
const float m1,
|
||||
const float logit_softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const int ne31, const int nb31,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
#if defined(INT8_MMA_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
@@ -1455,12 +1440,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
|
||||
|
||||
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z) / gridDim.x;
|
||||
const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
@@ -1470,18 +1456,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k;
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02*zt*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
@@ -1494,12 +1481,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
kbc += iter_k;
|
||||
@@ -1513,18 +1500,20 @@ static __global__ void flash_attn_ext_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const int sequence = kbc / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* zt*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
@@ -1536,7 +1525,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
@@ -1558,6 +1547,7 @@ __launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
constexpr int ne03 = 1;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
const int j = blockIdx.y;
|
||||
@@ -1569,9 +1559,10 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -1580,14 +1571,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
const int sequence = kbc0 / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
@@ -1606,7 +1598,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
@@ -1846,7 +1838,8 @@ static void launch_fattn_new_mma(
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
@@ -1921,7 +1914,7 @@ static void launch_fattn_new_mma(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
blocks_num.z = ntiles_z*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@@ -2161,6 +2154,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 1, 32>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else if (gqa_ratio % 4 == 0) {
|
||||
|
||||
@@ -107,7 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
|
||||
// so no other implementation works.
|
||||
//
|
||||
if (new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 &&
|
||||
if (false && new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 &&
|
||||
Q->ne[2]/K->ne[2] == 20 && K->ne[1] > 8192) {
|
||||
// GLM-4.7-Flash TG hack: split 20 heads into 16+4 heads
|
||||
auto local_Q = *Q;
|
||||
|
||||
Reference in New Issue
Block a user