mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
CUDA WIP: support for FlashMLA-3
This commit is contained in:
@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
|
||||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
|
||||
}
|
||||
if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
|
||||
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
|
||||
int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
|
||||
return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
|
||||
}
|
||||
if (op->src[1]->ne[0] > 256) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn-new-mma.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
@@ -519,10 +520,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
// We need this because I haven't adapted the MMA kernels to work for different
|
||||
// K and V head sizes.
|
||||
if (K->ne[0] != V->ne[0]) {
|
||||
//if (K->ne[0] != V->ne[0]) {
|
||||
if (!new_mma_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
//ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user