CUDA: add head size of 64 to new mma

Not enabled yet, but I observe slightly better PP (prompt processing) and
slightly worse TG (token generation) performance with it.
Author: Iwan Kawrakow
Date:   2025-08-11 11:10:45 +03:00
Commit: 464b8fc03b
Parent: 3cd7e5c9b4


@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
 // Perhaps the 256 head size needs a closer look
 // to see if this implementation is better.
 //
-//template <>
-//struct fattn_mma_f16_config< 64, 64> {
-//    static constexpr int nbatch_fa = 64;
-//    static constexpr int nwarps_max = 4;
-//    static constexpr bool Q_in_reg = true;
-//    static constexpr int nstages_target = 2;
-//
-//    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-//        return 32;
-//    }
-//
-//    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-//        return 32;
-//    }
-//
-//    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-//        return 32;
-//    }
-//
-//    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-//        return 32;
-//    }
-//
-//    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-//        return 32;
-//    }
-//
-//    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-//        return 32;
-//    }
-//};
+template <>
+struct fattn_mma_f16_config< 64, 64> {
+    static constexpr int nbatch_fa = 64;
+    static constexpr int nwarps_max = 4;
+    static constexpr bool Q_in_reg = true;
+    static constexpr int nstages_target = 2;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
+};
 //
 //template <>
 //struct fattn_mma_f16_config< 80, 80> {
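
For orientation, such a specialization is consumed entirely at compile time: the kernels alias it as c and read its constants (e.g. c::nbatch_fa in the next hunk). A minimal sketch of that pattern, assuming the <64, 64> specialization above is in scope; the demo kernel and its output are hypothetical, not actual kernel code:

    // Hypothetical illustration of consuming a fattn_mma_f16_config specialization.
    template <int DKQ, int DV>
    __global__ void fattn_demo_kernel(int * out) {
        typedef fattn_mma_f16_config<DKQ, DV> c;   // <64, 64> selects the struct above
        constexpr int nbatch_fa = c::nbatch_fa;                         // 64 KV positions per iteration
        const int nbatch_K2 = c::get_nbatch_K2_device(/*ncols=*/8);    // 32 half2 elements of K per batch
        // in the real kernels, shared-memory sizes and loop bounds are derived
        // from constants like these
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            *out = nbatch_fa + nbatch_K2;
        }
    }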
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     float KQ_rowsum_add[cols_per_thread] = {0.0f};
     if constexpr (ntiles == 1) {
-        if constexpr (ncols2 > 1 || mask_h2) {
+        if (ncols2 > 1 || mask_h2) {
 #pragma unroll
             for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
                 const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
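
The hunk above swaps an if constexpr for a plain if. For reference, if constexpr requires a constant-expression condition and discards the untaken branch at compile time, while a plain if is always compiled and evaluated at run time; a toy C++ illustration (names are not from the kernel):

    // Toy example of compile-time vs run-time branching; illustrative only.
    template <int ncols2>
    void process(const void * mask) {
        if constexpr (ncols2 > 1) {
            // discarded at compile time whenever ncols2 <= 1
        }
        if (ncols2 > 1 || mask != nullptr) {
            // always compiled; evaluated at run time because 'mask' is not a constant expression
        }
    }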
@@ -1905,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
+    if (use_gqa_opt && gqa_ratio % 16 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
+        return;
+    }
     if (use_gqa_opt && gqa_ratio % 8 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
         return;
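
The added branch lets grouped-query attention with a Q-to-KV head ratio divisible by 16 take the widest ncols2 path before the existing % 8 case. A hedged sketch of the resulting selection order, with hypothetical head counts and a made-up helper name:

    // Illustrative only: how ncols2 falls out of the GQA ratio with the added % 16 case.
    // E.g. a model with 64 Q heads and 4 KV heads gives gqa_ratio = 16.
    static int pick_ncols2(int n_q_heads, int n_kv_heads, bool use_gqa_opt) {
        const int gqa_ratio = n_q_heads / n_kv_heads;       // 64 / 4 = 16
        if (use_gqa_opt && gqa_ratio % 16 == 0) return 16;  // new branch, tried first
        if (use_gqa_opt && gqa_ratio %  8 == 0) return  8;  // pre-existing case
        // the real switch continues with smaller group sizes before falling back
        return 1;
    }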
@@ -1930,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     const ggml_tensor * V = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
-    GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -1940,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
+    if (K->ne[0] == 64 && V->ne[0] == 64) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
+        return;
+    }
+    GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     GGML_ASSERT(gqa_ratio % 16 == 0);
     ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
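
Taken together with the hunk above that drops the unconditional shape assert, the new-MMA entry point now accepts two shapes: regular attention with head size 64/64 via the ncols2 switch, and the existing MLA shape 576/512. A condensed, illustrative view of that dispatch; the enum and function below are hypothetical, the real code works on ggml_tensor shapes:

    // Illustrative summary of the dispatch in ggml_cuda_flash_attn_ext_mma_new after this commit.
    enum class fattn_path { head64, mla };

    static fattn_path pick_fattn_path(int dkq, int dv) {
        if (dkq == 64 && dv == 64) {
            return fattn_path::head64;  // new: regular attention with 64-dim heads
        }
        // everything else must be the MLA shape (K/Q head size 576, V head size 512),
        // handled by the <576, 512, 16> instantiation and asserted in the real code
        return fattn_path::mla;
    }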