mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 19:10:03 +00:00
FlashMLA-3 for DeepSeek models on CUDA (#386)
* CUDA WIP: support for FlashMLA-3 * Much better The issue was that I did not change the number of warps used for 3D matrix multiplications (wk_b * kv_cache, MoE), so we ended up using 4 warps for TG. By going to 1 warp in these cases, we get a significant boost in TG performance (tested with DeepSeek-Lite) * Sadly, the previous commit was wrong * Finalizing * Also add these * Minor * Minor tweak --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
|
||||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
|
||||
}
|
||||
if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
|
||||
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
|
||||
int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
|
||||
return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
|
||||
}
|
||||
if (op->src[1]->ne[0] > 256) {
|
||||
return false;
|
||||
}
|
||||
|
||||
1705
ggml/src/ggml-cuda/fattn-new-mma.cu
Normal file
1705
ggml/src/ggml-cuda/fattn-new-mma.cu
Normal file
File diff suppressed because it is too large
Load Diff
3
ggml/src/ggml-cuda/fattn-new-mma.cuh
Normal file
3
ggml/src/ggml-cuda/fattn-new-mma.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "fattn-vec-f32.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn-mma-f16.cuh"
|
||||
#include "fattn-new-mma.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
@@ -517,12 +518,28 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
// We need this because I haven't adapted the MMA kernels to work for different
|
||||
//
|
||||
// It turns out the new new MMA implementation is slower than the
|
||||
// previous MMA implementation.
|
||||
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
|
||||
// so no other implementation works.
|
||||
//
|
||||
if (new_mma_available(cc) && Q->ne[0] == 576) {
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
//
|
||||
// We need this because I haven't adapted new MMA kernels to work for different
|
||||
// K and V head sizes.
|
||||
if (K->ne[0] != V->ne[0]) {
|
||||
// We also need it if the new MMA is not available
|
||||
//
|
||||
if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
|
||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
// As mentioned above, the new new MMA is slower than then the new MMA.
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_y>
|
||||
template <ggml_type type, int ncols_y, int nwarps>
|
||||
static __device__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
@@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q(
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
|
||||
@@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q(
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_y>
|
||||
template <ggml_type type, int ncols_y, int nwarps>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
|
||||
@@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q(
|
||||
const char * cx = (const char *)vx + i02*nb02;
|
||||
const char * cy = (const char *)vy + i2*nb12;
|
||||
char * cdst = (char *)dst + i2*nb2;
|
||||
mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_cuda(
|
||||
template <ggml_type type, int nwarps>
|
||||
static void mul_mat_vec_q_cuda_T(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
|
||||
@@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda(
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
|
||||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
|
||||
ncols_y < 4 ? 1 : 2 : 1;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
nwarps = 4;
|
||||
rows_per_cuda_block = 1;
|
||||
break;
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
nwarps = 4;
|
||||
rows_per_cuda_block = 2;
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
case 7:
|
||||
case 8:
|
||||
nwarps = 2;
|
||||
rows_per_cuda_block = 2;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
//if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
// switch(ncols_y) {
|
||||
// case 1:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 1;
|
||||
// break;
|
||||
// case 2:
|
||||
// case 3:
|
||||
// case 4:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// case 5:
|
||||
// case 6:
|
||||
// case 7:
|
||||
// case 8:
|
||||
// nwarps = 2;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// default:
|
||||
// GGML_ABORT("fatal error");
|
||||
// break;
|
||||
// }
|
||||
//}
|
||||
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||
const dim3 block_nums(nblocks, ne2, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
switch (ncols_y) {
|
||||
case 1:
|
||||
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 2:
|
||||
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 4:
|
||||
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 5:
|
||||
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_cuda(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
|
||||
int nwarps = 1;
|
||||
int id = ggml_cuda_get_device();
|
||||
if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
}
|
||||
switch (nwarps) {
|
||||
case 1:
|
||||
mul_mat_vec_q_cuda_T<type, 1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
case 2:
|
||||
mul_mat_vec_q_cuda_T<type, 2>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
default:
|
||||
mul_mat_vec_q_cuda_T<type, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
|
||||
Reference in New Issue
Block a user