mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Gigachat: CUDA FA (needs 192 x 192 for MLA = 3)
This commit is contained in:
@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
|||||||
case 128:
|
case 128:
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case 192:
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
|
||||||
|
break;
|
||||||
case 256:
|
case 256:
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
|
|||||||
auto K = dst->src[1];
|
auto K = dst->src[1];
|
||||||
auto V = dst->src[1];
|
auto V = dst->src[1];
|
||||||
if (K->ne[0] != V->ne[0]) return false;
|
if (K->ne[0] != V->ne[0]) return false;
|
||||||
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
|
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct fattn_mma_f16_config<192, 192> {
|
||||||
|
static constexpr int nbatch_fa = 64;
|
||||||
|
static constexpr int nwarps_max = 4;
|
||||||
|
static constexpr bool Q_in_reg = true;
|
||||||
|
static constexpr int nstages_target = 1;
|
||||||
|
|
||||||
|
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||||
|
return 64;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||||
|
return 32;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct fattn_mma_f16_config<576, 512> {
|
struct fattn_mma_f16_config<576, 512> {
|
||||||
static constexpr int nbatch_fa = 32;
|
static constexpr int nbatch_fa = 32;
|
||||||
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
|||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (K->ne[0] == 192 && V->ne[0] == 192) {
|
||||||
|
GGML_ASSERT(Q->ne[0] == 192);
|
||||||
|
GGML_ASSERT(gqa_ratio == 1);
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||||
|
|||||||
Reference in New Issue
Block a user