CUDA FA WIP - It actually works!

No TG yet, but for PP I can run FA with fp16 cache and it gets
the same answer.
This commit is contained in:
Iwan Kawrakow
2025-03-03 18:52:34 +02:00
parent 3c72a7b47c
commit 0a6542b503

View File

@@ -343,6 +343,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;