FlashMLA-3 for DeepSeek models on CUDA (#386)

* CUDA WIP: support for FlashMLA-3

* Much better

The issue was that I did not change the number of warps
used for 3D matrix multiplications (wk_b * kv_cache, MoE),
so we ended up using 4 warps for TG. By going to 1 warp
in these cases, we get a significant boost in TG performance
(tested with DeepSeek-Lite)

* Sadly, the previous commit was wrong

* Finalizing

* Also add these

* Minor

* Minor tweak

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-05-07 17:38:22 +03:00
committed by GitHub
parent 5436acdb6c
commit 92ceda1d06
5 changed files with 1797 additions and 44 deletions

View File

@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
}
if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
}
if (op->src[1]->ne[0] > 256) {
return false;
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -13,6 +13,7 @@
#include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-new-mma.cuh"
#include "fattn.cuh"
#include <cstdint>
@@ -517,12 +518,28 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
// We need this because I haven't adapted the MMA kernels to work for different
//
// It turns out the new new MMA implementation is slower than the
// previous MMA implementation.
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
// so no other implementation works.
//
if (new_mma_available(cc) && Q->ne[0] == 576) {
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
return;
}
//
// We need this because I haven't adapted new MMA kernels to work for different
// K and V head sizes.
if (K->ne[0] != V->ne[0]) {
// We also need it if the new MMA is not available
//
if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
// As mentioned above, the new new MMA is slower than then the new MMA.
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
}

View File

@@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
}
}
template <ggml_type type, int ncols_y>
template <ggml_type type, int ncols_y, int nwarps>
static __device__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
@@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q(
}
}
template <ggml_type type, int ncols_y>
template <ggml_type type, int ncols_y, int nwarps>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
@@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q(
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
template <ggml_type type>
static void mul_mat_vec_q_cuda(
template <ggml_type type, int nwarps>
static void mul_mat_vec_q_cuda_T(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
@@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda(
int id = ggml_cuda_get_device();
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
ncols_y < 4 ? 1 : 2 : 1;
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
rows_per_cuda_block = 1;
break;
case 2:
case 3:
case 4:
nwarps = 4;
rows_per_cuda_block = 2;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
rows_per_cuda_block = 2;
break;
default:
GGML_ABORT("fatal error");
break;
}
}
//if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
// switch(ncols_y) {
// case 1:
// nwarps = 4;
// rows_per_cuda_block = 1;
// break;
// case 2:
// case 3:
// case 4:
// nwarps = 4;
// rows_per_cuda_block = 2;
// break;
// case 5:
// case 6:
// case 7:
// case 8:
// nwarps = 2;
// rows_per_cuda_block = 2;
// break;
// default:
// GGML_ABORT("fatal error");
// break;
// }
//}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, ne2, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
switch (ncols_y) {
case 1:
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 2:
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 3:
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 4:
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 5:
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 6:
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 7:
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 8:
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
default:
GGML_ABORT("fatal error");
@@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda(
}
}
template <ggml_type type>
static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
int nwarps = 1;
int id = ggml_cuda_get_device();
if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
nwarps = ncols_y <= 4 ? 4 : 2;
}
switch (nwarps) {
case 1:
mul_mat_vec_q_cuda_T<type, 1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case 2:
mul_mat_vec_q_cuda_T<type, 2>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
default:
mul_mat_vec_q_cuda_T<type, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
ne2, nb02, nb12, nb2, ids_nb0, stream);
}
}
static void mul_mat_vec_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,