mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
WIP
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
#include "fattn-compat.cuh"
|
||||
//#include "fattn-prev-mma-f16-interface.cuh"
|
||||
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -87,19 +88,24 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
|
||||
if constexpr (DKQ == 128 && DV == 128) {
|
||||
if (use_gqa_opt && gqa_ratio == 12) {
|
||||
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 1, 16>(ctx, dst);
|
||||
if (Q->ne[1] <= 8) {
|
||||
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] == 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 1, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 2, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 2, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, 16>(ctx, dst);
|
||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
|
||||
return;
|
||||
//else if (ggml_cuda_fattn_prev_mma_f16_is_supported(ctx, dst)) {
|
||||
// ggml_cuda_flash_attn_ext_prev_mma_f16(ctx, dst);
|
||||
// return;
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -485,7 +491,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const int32_t n_swa = KQV->op_params[4];
|
||||
|
||||
ggml_tensor local_dst, Kl, Vl, Ml;
|
||||
|
||||
Reference in New Issue
Block a user