Attn logit softcapping (for eager attn)

turboderp
2024-07-06 07:13:58 +02:00
parent 6095c0eb6e
commit 01ce7bbb6e
2 changed files with 9 additions and 4 deletions


@@ -766,7 +766,10 @@ class ExLlamaV2Attention(ExLlamaV2Module):
         attn_weights *= self.scaling
         attn_mask = attn_params.get_attn_mask(attn_weights.device)
-        if attn_mask is not None: attn_weights = attn_weights + attn_mask
+        if cfg.attn_logit_softcapping:
+            ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping)
+        if attn_mask is not None:
+            attn_weights = attn_weights + attn_mask
         if self.sliding_window and k_states.shape[-1] >= self.sliding_window:
             attn_weights = attn_weights[:, :, :, -self.sliding_window:]
             v_states = v_states[:, :, -self.sliding_window:, :]
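For reference, attention logit softcapping (used by Gemma-2-style models) squashes the raw attention scores through a scaled tanh before the mask and softmax, bounding every logit to (-cap, cap). The `ext_c.softcap_` call above runs in the C++/CUDA extension; a minimal pure-PyTorch sketch of the same in-place transform, assuming the conventional `cap * tanh(x / cap)` formulation, would be:

```python
import torch

def softcap_(attn_weights: torch.Tensor, cap: float) -> torch.Tensor:
    # In-place logit softcap: x <- cap * tanh(x / cap).
    # Smooth and monotonic, and keeps every logit inside (-cap, cap).
    attn_weights.div_(cap).tanh_().mul_(cap)
    return attn_weights

# Example: cap raw attention scores at +/-50.0 (the attention cap published for Gemma-2)
scores = torch.randn(1, 8, 16, 16) * 100.0
softcap_(scores, 50.0)
assert scores.abs().max().item() < 50.0
```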


@@ -102,6 +102,7 @@ class ExLlamaV2Config:
     use_qk_norm: bool
     query_pre_attn_scalar: float | None
     final_logit_softcapping: float | None
+    attn_logit_softcapping: float | None
     sliding_window: int
     checkpoint_fused_mlp: bool
@@ -165,9 +166,9 @@ class ExLlamaV2Config:
         # Load generation_config.json

-        self.generation_config_path = os.path.join(self.model_dir, "generation_config.json")
-        if os.path.exists(self.generation_config_path):
-            with open(self.generation_config_path, encoding = "utf8") as f:
+        generation_config_path = os.path.join(self.model_dir, "generation_config.json")
+        if os.path.exists(generation_config_path):
+            with open(generation_config_path, encoding = "utf8") as f:
                 gen_config = json.load(f)
             self.generation_config = {}
             try:
@@ -247,6 +248,7 @@ class ExLlamaV2Config:
         else:
             self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers)

+        self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)
         self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None)

         # Positional embeddings
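Both softcap values are optional floats pulled from the checkpoint's config.json via the config module's `read` helper. As an illustration only (`read_float` below is a hypothetical stand-in for that helper, and the 50.0 / 30.0 values are the ones published with Gemma-2 checkpoints), the lookup amounts to:

```python
import json

def read_float(cfg: dict, key: str, default: float | None = None) -> float | None:
    # Hypothetical stand-in for the config's read() helper: optional float key lookup.
    value = cfg.get(key, default)
    return float(value) if value is not None else None

# Fragment of a Gemma-2-style config.json (values as published for Gemma-2)
read_config = json.loads('{"attn_logit_softcapping": 50.0, "final_logit_softcapping": 30.0}')

attn_cap  = read_float(read_config, "attn_logit_softcapping")   # caps attention scores
final_cap = read_float(read_config, "final_logit_softcapping")  # caps output logits
print(attn_cap, final_cap)  # 50.0 30.0
```

Models whose config.json omits these keys simply get `None`, so the eager-attention path above skips the softcap entirely.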