Support (alternating) SWA

This commit is contained in:
turboderp
2024-07-04 05:36:07 +02:00
parent 84d00cbbc0
commit 66c4a9c849
4 changed files with 46 additions and 4 deletions

View File

@@ -116,6 +116,9 @@ class ExLlamaV2ArchParams:
self.norm_key_1_post = None
self.norm_key_2_post = None
self.swa = False
self.alternating_swa = False
self.fused_qkv_altpack = False
# Mistral

View File

@@ -15,6 +15,7 @@ from exllamav2.architecture import RopeStyle
import math
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
import torch.nn.functional as F
import inspect
# from line_profiler import profile
from typing import TYPE_CHECKING
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
has_flash_attn = False
has_flash_attn_with_paged = False
has_flash_attn_with_window = False
try:
import flash_attn
@@ -45,6 +47,9 @@ try:
has_flash_attn = True
has_flash_attn_with_paged = True
has_flash_attn_with_window = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ModuleNotFoundError:
pass
@@ -277,7 +282,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
key: str,
layer_idx: int,
has_norm: bool = True,
has_residual: bool = True):
has_residual: bool = True,
sliding_window: int = 0):
super().__init__(model, key)
@@ -337,6 +343,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
else:
self.scaling = 1 / math.sqrt(cfg.head_dim)
self.sliding_window = sliding_window
def numel(self) -> int:
@@ -682,6 +690,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
# block_table = block_table,
# causal = True
# )
window_size = -1 if not self.sliding_window else self.sliding_window
attn_output, _ = flash_attn_cuda.fwd_kvcache(
q, k_cache, v_cache, k, v,
cache_seqlens_a,
@@ -692,7 +702,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
None,
self.scaling,
True,
-1, -1,
window_size, window_size,
True,
0,
)
@@ -733,6 +743,10 @@ class ExLlamaV2Attention(ExLlamaV2Module):
k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
if self.sliding_window and k_states.shape[2] >= self.sliding_window:
k_states = k_states[:, :, -self.sliding_window:, :]
v_states = v_states[:, :, -self.sliding_window:, :]
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
attn_output = F.scaled_dot_product_attention(
q_states,
@@ -750,7 +764,13 @@ class ExLlamaV2Attention(ExLlamaV2Module):
attn_weights = torch.matmul(q_states, k_states)
attn_weights *= self.scaling
attn_mask = attn_params.get_attn_mask(attn_weights.device)
if attn_mask is not None: attn_weights = attn_weights + attn_mask
if self.sliding_window and k_states.shape[-1] >= self.sliding_window:
attn_weights = attn_weights[:, :, :, -self.sliding_window:]
v_states = v_states[:, :, -self.sliding_window:, :]
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
@@ -763,13 +783,20 @@ class ExLlamaV2Attention(ExLlamaV2Module):
def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
    """
    Run causal attention via flash-attn's flash_attn_func.

    When this layer uses a sliding window (self.sliding_window > 0), the
    installed flash-attn build must expose the `window_size` parameter;
    otherwise the call is made without any window restriction.

    Returns the attention output reshaped to
    (batch_size, q_len, num_attention_heads * head_dim).
    """
    # Only pass window_size when SWA is active for this layer; older
    # flash-attn versions reject the keyword entirely.
    if self.sliding_window:
        assert has_flash_attn_with_window, \
            "Installed version of flash-attn does not support sliding window"
        extra_kwargs = {"window_size": (self.sliding_window, self.sliding_window)}
    else:
        extra_kwargs = {}

    out = flash_attn_func(
        q_states,
        k_states,
        v_states,
        causal = True,
        softmax_scale = self.scaling,
        **extra_kwargs
    )
    # Merge the head and head_dim axes back into one hidden dimension
    return out.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
@@ -777,6 +804,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
assert not self.sliding_window, \
"Sliding window not currently supported for xformers"
# xformers memory_efficient_attention, can be beneficial if your device's architecture is below sm_80
# xformers does not expand the KV heads automatically, so we need to do it manually. The efficiency of
# xformers.memory_efficient_attention and flash_attn on >= sm_80 is almost the same. But the matrix operation

View File

@@ -102,6 +102,7 @@ class ExLlamaV2Config:
use_qk_norm: bool
query_pre_attn_scalar: float | None
final_logit_softcapping: float | None
sliding_window: int
checkpoint_fused_mlp: bool
@@ -259,6 +260,8 @@ class ExLlamaV2Config:
"n_positions"], 2048)
self.original_max_seq_len = self.max_seq_len
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0)
rs = read(read_config, dict, "rope_scaling", None)
if rs:
scaling_type = rs.get("type", None)

View File

@@ -227,7 +227,13 @@ class ExLlamaV2:
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
self.modules += [pd]
else:
attn = ExLlamaV2Attention(self, layer_key, layer_idx)
if self.config.arch.alternating_swa:
swa = self.config.sliding_window if not bool(layer_idx % 2) else 0
elif self.config.arch.swa:
swa = self.config.sliding_window
else:
swa = 0
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
self.modules += [attn, mlp]