mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 22:24:11 +00:00
This seems much better for GQA = 12 TG
This commit is contained in:
@@ -39,6 +39,8 @@ typedef void (* fattn_new_mma_kernel_t)(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int ne33,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -211,37 +213,37 @@ struct fattn_mma_f16_config;
|
||||
// }
|
||||
//};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config<128, 128> {
|
||||
// static constexpr int nbatch_fa = 64;
|
||||
// static constexpr int nwarps_max = 4;
|
||||
// static constexpr bool Q_in_reg = true;
|
||||
// static constexpr int nstages_target = 2;
|
||||
//
|
||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
// return 64;
|
||||
// }
|
||||
//};
|
||||
// Flash-attention MMA configuration for head sizes DKQ = 128, DV = 128.
// Every batch granularity is fixed at 64 for this head-size combination;
// the host/device getter pairs exist so callers can query the value both
// at kernel-selection time (host) and inside device code (constexpr).
template <>
struct fattn_mma_f16_config<128, 128> {
    static constexpr int  nbatch_fa      = 64;   // KV elements handled per flash-attention iteration
    static constexpr int  nwarps_max     = 4;    // upper bound on warps per CUDA block for this config
    static constexpr bool Q_in_reg      = true;  // presumably keeps the Q tile in registers rather than shared memory — confirm against kernel
    static constexpr int  nstages_target = 2;    // target software-pipeline depth (double buffering)

    // Host-side K batch size (in half2 units); ignores compute capability and column count here.
    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
        return 64;
    }

    // Device-side K batch size (in half2 units); must agree with the host getter.
    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
        return 64;
    }

    // Host-side V batch size (in half2 units).
    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
        return 64;
    }

    // Device-side V batch size (in half2 units); must agree with the host getter.
    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
        return 64;
    }

    // Host-side batch size used when combining partial results.
    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
        return 64;
    }

    // Device-side combine batch size; must agree with the host getter.
    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
        return 64;
    }
};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config<256, 256> {
|
||||
@@ -930,6 +932,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float logit_softcap,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int gqa_ratio,
|
||||
const int stride_Q1,
|
||||
const int stride_Q2,
|
||||
const int stride_K,
|
||||
@@ -1009,7 +1012,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j = jc / ncols2;
|
||||
const int c = jc % ncols2;
|
||||
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < gqa_ratio)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
@@ -1337,7 +1340,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int j_dst = jc_dst / ncols2;
|
||||
const int c_dst = jc_dst % ncols2;
|
||||
|
||||
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
|
||||
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= gqa_ratio))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1381,6 +1384,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
||||
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
|
||||
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
||||
GGML_UNUSED(gqa_ratio);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
@@ -1404,7 +1408,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const int ne31, const int nb31,
|
||||
const int ne31, const int nb31, const int ne33, const int nb33,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13,
|
||||
const int nb21, const int nb22, const int nb23,
|
||||
@@ -1440,13 +1444,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
|
||||
const int iter_z = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z) / gridDim.x;
|
||||
const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z) / gridDim.x;
|
||||
int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x;
|
||||
const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
@@ -1456,19 +1460,22 @@ static __global__ void flash_attn_ext_f16(
|
||||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int sequence = kbc / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k;
|
||||
const int sequence = kbc /(iter_k*iter_j*iter_z*ne12);
|
||||
const int z_KV = (kbc - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt) / iter_k;
|
||||
const int zt_Q = z_KV*gqa_ratio + zt*ncols2; // Global Q head start index.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02*zt*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
|
||||
//const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
@@ -1481,12 +1488,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
kbc += iter_k;
|
||||
@@ -1500,20 +1507,23 @@ static __global__ void flash_attn_ext_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
const int sequence = kbc /(iter_k*iter_j*iter_z*ne12);
|
||||
const int z_KV = (kbc - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt) / iter_k;
|
||||
|
||||
const int sequence = kbc / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
|
||||
const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
const int zt_Q = z_KV*gqa_ratio + zt*ncols2; // Global Q head start index.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* zt*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr;
|
||||
//const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
@@ -1525,7 +1535,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
@@ -1545,7 +1555,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02,
|
||||
const int ne11, const int ne12) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
constexpr int ne03 = 1;
|
||||
|
||||
@@ -1557,12 +1568,14 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
|
||||
const int gqa_ratio = ne02/ne12;
|
||||
|
||||
const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE; // In our implementation ne11 is always a multiple of FATTN_KQ_STRIDE
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -1571,15 +1584,18 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
return;
|
||||
}
|
||||
|
||||
const int sequence = kbc0 / (iter_k*iter_j*iter_z);
|
||||
const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc0 /(iter_k*iter_j*iter_z*ne12);
|
||||
const int z_KV = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z);
|
||||
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
|
||||
if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
@@ -1598,7 +1614,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
@@ -1837,9 +1853,10 @@ static void launch_fattn_new_mma(
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const int ntiles_z = ((gqa_ratio + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z * K->ne[2] * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
@@ -1914,7 +1931,7 @@ static void launch_fattn_new_mma(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = ntiles_z*Q->ne[3];
|
||||
blocks_num.z = ntiles_z*K->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@@ -1951,7 +1968,8 @@ static void launch_fattn_new_mma(
|
||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
mask ? mask->ne[3] : 0, mask ? mask->nb[3] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
@@ -1966,7 +1984,7 @@ static void launch_fattn_new_mma(
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1], K->ne[2]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
@@ -2135,6 +2153,13 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (K->ne[0] == 128 && gqa_ratio == 12) {
|
||||
GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128);
|
||||
//GGML_ASSERT(Q->ne[1] <= 4);
|
||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
//if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
// ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
// return;
|
||||
|
||||
@@ -89,6 +89,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 &&
|
||||
Q->ne[2] / K->ne[2] == 12) {
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
// So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
|
||||
// On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
|
||||
|
||||
Reference in New Issue
Block a user