mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-03 21:09:39 +00:00
DeepSeek FA optimizations (#929)
* Use new-new-mma also for MLA=3, and use mask bounds This gives us ~25% better PP at 32k tokens compared to main * This seems better --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -14,6 +14,46 @@
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
typedef void (* fattn_new_mma_kernel_t)(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const float softcap,
|
||||
const uint32_t n_head_log2,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
const int ne03,
|
||||
const int ne10,
|
||||
const int ne11,
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int nb31,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3);
|
||||
|
||||
|
||||
typedef tile<16, 8, half2> tile_A;
|
||||
typedef tile< 8, 8, half2> tile_B;
|
||||
typedef tile<16, 8, half2> tile_B_16;
|
||||
@@ -43,37 +83,37 @@ struct fattn_mma_f16_config;
|
||||
// Perhaps the 256 head size needs a closer look
|
||||
// to see if this implementation is better.
|
||||
//
|
||||
template <>
|
||||
struct fattn_mma_f16_config< 64, 64> {
|
||||
static constexpr int nbatch_fa = 64;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 64, 64> {
|
||||
// static constexpr int nbatch_fa = 64;
|
||||
// static constexpr int nwarps_max = 4;
|
||||
// static constexpr bool Q_in_reg = true;
|
||||
// static constexpr int nstages_target = 2;
|
||||
//
|
||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 80, 80> {
|
||||
@@ -243,6 +283,38 @@ struct fattn_mma_f16_config< 64, 64> {
|
||||
// }
|
||||
//};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<192, 128> {
|
||||
static constexpr int nbatch_fa = 64;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 1;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nbatch_fa = 32;
|
||||
@@ -1287,6 +1359,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -1377,8 +1450,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
if (KV_max) {
|
||||
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
|
||||
}
|
||||
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
@@ -1417,8 +1493,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
int kb0_start_kernel = kb0_start * kb_niter;
|
||||
int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
if (KV_max) {
|
||||
kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
|
||||
}
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
@@ -1574,9 +1653,68 @@ static __global__ void flash_attn_combine_results_new(
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
|
||||
return __all_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template <int ncols1>
|
||||
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
||||
static __global__ void flash_attn_mask_to_KV_max(
|
||||
const half2 * __restrict__ mask, int * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
|
||||
const int ne31 = gridDim.x;
|
||||
const int tid = threadIdx.x;
|
||||
const int sequence = blockIdx.y;
|
||||
const int jt = blockIdx.x;
|
||||
|
||||
mask += sequence*s33 + jt*ncols1*s31;
|
||||
|
||||
__shared__ int buf_iw[WARP_SIZE];
|
||||
if (tid < WARP_SIZE) {
|
||||
buf_iw[tid] = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
|
||||
for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
|
||||
int all_inf = 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols1; ++j) {
|
||||
const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
|
||||
all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
|
||||
}
|
||||
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
if (tid % WARP_SIZE == 0) {
|
||||
buf_iw[tid / WARP_SIZE] = all_inf;
|
||||
}
|
||||
__syncthreads();
|
||||
all_inf = buf_iw[tid % WARP_SIZE];
|
||||
__syncthreads();
|
||||
all_inf = warp_reduce_all(all_inf);
|
||||
|
||||
if (!all_inf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KV_min_max[sequence*ne31 + jt] = KV_max_sj + FATTN_KQ_STRIDE;
|
||||
}
|
||||
}
|
||||
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
static void launch_fattn_new_mma(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_new_mma_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
@@ -1605,10 +1743,15 @@ static void launch_fattn_new_mma(
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
const int nsm_actual = ggml_cuda_info().devices[id].nsm;
|
||||
int nsm = 1; while (nsm*2 <= nsm_actual) nsm *= 2;
|
||||
|
||||
if (Q->ne[1] == 1 && K->ne[1] <= 4096 && nsm > 32) nsm /= 2;
|
||||
if (Q->ne[1] >= 32 && K->ne[1] >= 4096) nsm *= 2;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<int> KV_max(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
@@ -1675,6 +1818,25 @@ static void launch_fattn_new_mma(
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
||||
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
|
||||
|
||||
const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
|
||||
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
||||
|
||||
KV_max.alloc(ne_KV_max);
|
||||
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
@@ -1761,6 +1923,7 @@ static void launch_fattn_new_mma(
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
KV_max.get(),
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
@@ -1837,7 +2000,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
fattn_kernel_t fattn_kernel;
|
||||
fattn_new_mma_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
@@ -1944,8 +2107,16 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
//if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
// ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
// return;
|
||||
//}
|
||||
if (K->ne[0] == 192 && V->ne[0] == 128) {
|
||||
GGML_ASSERT(Q->ne[0] == 192);
|
||||
GGML_ASSERT(gqa_ratio == 1);
|
||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
|
||||
// Reduce compile time
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
|
||||
@@ -102,7 +102,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
// Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
|
||||
// so no other implementation works.
|
||||
//
|
||||
if (new_mma_available(cc) && Q->ne[0] == 576) {
|
||||
if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128))) {
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@@ -172,8 +172,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
|
||||
return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
|
||||
}
|
||||
|
||||
if (new_mma_available(cc) && Q->ne[0] == 576) {
|
||||
return V->ne[0] == 512;
|
||||
if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
|
||||
|
||||
Reference in New Issue
Block a user