Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
CUDA: correctly detect if flash attention is supported (#875)

* Don't use vector kernels if K or V are quantized
* Correctly determine if FA is supported
* Also wmma
* Minor

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
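The new support check is reachable from application code through ggml's public backend API, so a caller can probe whether a given FLASH_ATTN_EXT node will actually run on the GPU. A minimal caller-side sketch (not part of this commit; the helper name, the `cuda_backend` handle and the pre-built graph `node` are assumptions for illustration):

    #include "ggml-backend.h"

    // node is a graph tensor with node->op == GGML_OP_FLASH_ATTN_EXT
    static bool fattn_runs_on_gpu(ggml_backend_t cuda_backend, const struct ggml_tensor * node) {
        // On the non-HIP path this ends up in ggml_backend_cuda_supports_op(),
        // which after this commit delegates to ggml_cuda_fattn_is_supported()
        // (see the first hunk below).
        return ggml_backend_supports_op(cuda_backend, node);
    }

If the check fails, the usual fallback is to let the backend scheduler place the node on the CPU backend, or to run without flash attention enabled.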
@@ -4386,31 +4386,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
#else
            if (op->src[0]->ne[0] == 128) {
                return true;
            }
            if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
                (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
                (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
                return true;
            }
            if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) {
                return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
                       (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
            }
            if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
                const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
                int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
                return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
            }
            if (op->src[1]->ne[0] > 256) {
                return false;
            }
            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                return true;
            }
            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                   op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
            return ggml_cuda_fattn_is_supported(*cuda_ctx, op);
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
        default:
            return false;
@@ -3,3 +3,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_mma_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -83,3 +83,10 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
}

bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) return false;
    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
}
@@ -353,3 +353,10 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
    }
}

bool ggml_cuda_fattn_tile_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) return false;
    return K->ne[0] == 64 || K->ne[0] == 128;
}
@@ -1,3 +1,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_tile_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -352,3 +352,10 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
    }
}

bool ggml_cuda_fattn_tile_f32_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) return false;
    return K->ne[0] == 64 || K->ne[0] == 128;
}
@@ -1,3 +1,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_tile_f32_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -3,3 +3,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_vec_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -102,4 +102,50 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
    on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
}

bool ggml_cuda_fattn_vec_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) {
        if (K->ne[0] != 192 || V->ne[0] != 128) return false;
        if (K->type != V->type) return false;
        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    }
#ifdef GGML_CUDA_FA_ALL_QUANTS
    if (K->ne[0] == 64) {
        return K->type == GGML_TYPE_F16 &&
               (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 ||
                V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0);
    }
    if (K->ne[0] == 256) {
        return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0);
    }
    if (K->ne[0] != 128 || V->ne[0] != 128) return false;
    if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 ||
         K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) &&
        (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 ||
         V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true;
    return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) ||
           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
#else
    if (K->ne[0] == 128) {
        if (K->type == V->type) {
            return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL;
        }
        return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
               (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
    }
    if (K->type != V->type) return false;
    if (K->ne[0] == 64) {
        return K->type == GGML_TYPE_F16;
    }
    if (K->ne[0] == 256) {
        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    }
    return false;
#endif
}
@@ -3,3 +3,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_vec_f32_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -24,7 +24,7 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
@@ -71,6 +71,12 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)

    FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
@@ -83,10 +89,63 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)

    FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
}

bool ggml_cuda_fattn_vec_f32_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) {
        if (K->ne[0] != 192 || V->ne[0] != 128) return false;
        if (K->type != V->type) return false;
        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    }
#ifdef GGML_CUDA_FA_ALL_QUANTS
    if (K->ne[0] == 64) {
        return K->type == GGML_TYPE_F16 &&
               (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 ||
                V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0);
    }
    if (K->ne[0] == 256) {
        return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0);
    }
    if (K->ne[0] != 128 || V->ne[0] != 128) return false;
    if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 ||
         K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) &&
        (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 ||
         V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true;
    return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) ||
           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
#else
    if (K->ne[0] == 128) {
        if (K->type == V->type) {
            return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL;
        }
        return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
               (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
    }
    if (K->type != V->type) return false;
    if (K->ne[0] == 64) {
        return K->type == GGML_TYPE_F16;
    }
    if (K->ne[0] == 256) {
        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    }
    return false;
#endif
}
@@ -3,3 +3,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_wmma_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
@@ -165,3 +165,9 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
    }
}

bool ggml_cuda_fattn_wmma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    auto K = dst->src[1];
    auto V = dst->src[2];
    if (K->ne[0] != V->ne[0]) return K->ne[0] == 192 && V->ne[0] == 128;
    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
}
@@ -91,12 +91,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        //if (precision == GGML_PREC_DEFAULT) {
        //    ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        //} else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        //}
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        return;
    }
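The hunk above implements the first bullet of the commit message: the batch-size-1 vector path is now taken only when neither cache tensor is quantized. A minimal standalone sketch of that gate, using ggml's public ggml_is_quantized() helper (the wrapper function is illustrative only, not part of this commit):

    #include "ggml.h"

    // True only for unquantized K/V caches, e.g. GGML_TYPE_F16; any quantized
    // cache type such as GGML_TYPE_Q8_0 now bypasses the batch-size-1 vector
    // kernels and falls through to the remaining kernel selection below.
    static bool vector_path_allowed(const ggml_tensor * K, const ggml_tensor * V) {
        return !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type);
    }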
@@ -125,3 +121,64 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
    //ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
}

bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[3];
    const int32_t n_swa = KQV->op_params[4];
    if (cc >= CC_OFFSET_AMD) {
        return precision == GGML_PREC_DEFAULT ? ggml_cuda_fattn_vec_f16_is_supported(ctx, dst)
                                              : ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
    }

    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
        } else {
            return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
        }
    }

    if (!fp16_mma_available(cc)) {
        if (precision == GGML_PREC_DEFAULT) {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                return ggml_cuda_fattn_vec_f16_is_supported(ctx, dst);
            } else {
                return ggml_cuda_fattn_tile_f16_is_supported(ctx, dst);
            }
        } else {
            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
            } else {
                return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
            }
        }
    }

    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
    // So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
    // On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
    //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
        return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
    }

    if (new_mma_available(cc) && Q->ne[0] == 576) {
        return V->ne[0] == 512;
    }

    if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
        return ggml_cuda_fattn_wmma_f16_is_supported(ctx, dst);
    }

    return ggml_cuda_fattn_mma_f16_is_supported(ctx, dst);
}
@@ -1,3 +1,5 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);