Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-24 23:24:13 +00:00).
Commit: "CUDA FA"
This commit is contained in:
@@ -378,6 +378,47 @@ struct fattn_mma_f16_config<576, 512> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
struct fattn_mma_f16_config<1088, 1024> {
    // FlashAttention MMA configuration for head sizes K = 1088 / V = 1024.
    // NOTE(review): the meaning of each knob is fixed by the primary template's
    // consumers (kernel launch code elsewhere in the file) — values here were
    // chosen conservatively; confirm against the <576, 512> specialization above.

    static constexpr int  nbatch_fa      = 32;    // KV elements processed per FA iteration
    static constexpr int  nwarps_max     = 8;     // upper bound on warps per thread block
    static constexpr bool Q_in_reg       = false; // head too wide to keep the Q tile in registers
    static constexpr int  nstages_target = 1;     // target pipeline stages (no double buffering)

    // K batch size ("K2" presumably counts half2 units — TODO confirm against callers).
    // Fixed at 64 for every compute capability and column count.
    static int get_nbatch_K2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
        return 64;
    }

    // Device-side counterpart of get_nbatch_K2_host; must agree with the host value.
    static constexpr __device__ int get_nbatch_K2_device([[maybe_unused]] int ncols) {
        return 64;
    }

    // V batch size. Arch-dependent tuning was tried and disabled; kept as a note
    // for future experiments:
    //   Turing:    ncols <= 16 ? 64  : 128
    //   otherwise: ncols <= 16 ? 256 : 128
    static int get_nbatch_V2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
        return 64;
    }

    // Device-side counterpart of get_nbatch_V2_host; must agree with the host value.
    // (Same disabled Turing/non-Turing tuning as the host variant — see note above it.)
    static constexpr __device__ int get_nbatch_V2_device([[maybe_unused]] int ncols) {
        return 64;
    }

    // Batch size for the combine step (an earlier revision used 128).
    static int get_nbatch_combine_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
        return 64;
    }

    // Device-side counterpart of get_nbatch_combine_host; must agree with the host value.
    static constexpr __device__ int get_nbatch_combine_device([[maybe_unused]] int ncols) {
        return 64;
    }
};
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
// The compiler is always able to unroll loops if they contain continue expressions.
|
||||
@@ -2165,6 +2206,20 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (Q->ne[0] == 1088 && K->ne[0] == 1088 && V->ne[0] == 1024) {
|
||||
GGML_ASSERT(gqa_ratio == 20);
|
||||
if (Q->ne[1] <= 4) {
|
||||
if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 32>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
//ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<1088, 1024, 4>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 4, 4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) {
|
||||
if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) {
|
||||
|
||||
@@ -114,7 +114,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
// so no other implementation works.
|
||||
//
|
||||
|
||||
if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
|
||||
if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 1088 && V->ne[0] == 1024) ||
|
||||
(K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
|
||||
//printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
@@ -185,8 +186,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
|
||||
return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
|
||||
}
|
||||
|
||||
if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
|
||||
if (Q->ne[0] == 576) {
|
||||
if (new_mma_available(cc) && (Q->ne[0] == 576 || Q->ne[0] == 1088 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
|
||||
if (Q->ne[0] == 576 || Q->ne[0] == 1088) {
|
||||
int gqa_ratio = Q->ne[2]/K->ne[2];
|
||||
return (gqa_ratio % 4) == 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user