diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 68c12dd3..b2285fdd 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1746,6 +1746,9 @@ static void launch_fattn_new_mma( const int nsm_actual = ggml_cuda_info().devices[id].nsm; int nsm = 1; while (nsm*2 <= nsm_actual) nsm *= 2; + if (Q->ne[1] == 1 && K->ne[1] <= 4096 && nsm > 32) nsm /= 2; + if (Q->ne[1] >= 32 && K->ne[1] >= 4096) nsm *= 2; + ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc KV_max(pool);