mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Fix GLM-4.5 attention (#700)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -96,7 +96,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
|
||||
[[maybe_unused]] const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_K = nb11 / sizeof(half);
|
||||
|
||||
@@ -78,7 +78,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
@@ -88,12 +88,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
@@ -508,7 +508,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
|
||||
const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
|
||||
Reference in New Issue
Block a user