Fix GLM-4.5 attention (#700)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-08-17 14:31:03 +03:00
committed by GitHub
parent d4d017766e
commit 1ca612375e
2 changed files with 5 additions and 5 deletions

View File

@@ -96,7 +96,7 @@ static __global__ void flash_attn_ext_f16(
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
[[maybe_unused]] const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
const int stride_Q = nb01 / sizeof(float);
const int stride_K = nb11 / sizeof(half);

View File

@@ -78,7 +78,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const float use_gqa_opt = mask && max_bias == 0.0f;
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -88,12 +88,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
return;
}
if (use_gqa_opt && gqa_ratio == 4) {
if (use_gqa_opt && gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio == 2) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
return;
}
@@ -508,7 +508,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (precision == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);