mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Support (alternating) SWA
This commit is contained in:
@@ -116,6 +116,9 @@ class ExLlamaV2ArchParams:
|
||||
self.norm_key_1_post = None
|
||||
self.norm_key_2_post = None
|
||||
|
||||
self.swa = False
|
||||
self.alternating_swa = False
|
||||
|
||||
self.fused_qkv_altpack = False
|
||||
|
||||
# Mistral
|
||||
|
||||
@@ -15,6 +15,7 @@ from exllamav2.architecture import RopeStyle
|
||||
import math
|
||||
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
|
||||
import torch.nn.functional as F
|
||||
import inspect
|
||||
# from line_profiler import profile
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
||||
|
||||
has_flash_attn = False
|
||||
has_flash_attn_with_paged = False
|
||||
has_flash_attn_with_window = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
@@ -45,6 +47,9 @@ try:
|
||||
has_flash_attn = True
|
||||
has_flash_attn_with_paged = True
|
||||
|
||||
has_flash_attn_with_window = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
@@ -277,7 +282,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
has_norm: bool = True,
|
||||
has_residual: bool = True):
|
||||
has_residual: bool = True,
|
||||
sliding_window: int = 0):
|
||||
|
||||
super().__init__(model, key)
|
||||
|
||||
@@ -337,6 +343,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
else:
|
||||
self.scaling = 1 / math.sqrt(cfg.head_dim)
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
|
||||
def numel(self) -> int:
|
||||
|
||||
@@ -682,6 +690,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
# block_table = block_table,
|
||||
# causal = True
|
||||
# )
|
||||
window_size = -1 if not self.sliding_window else self.sliding_window
|
||||
|
||||
attn_output, _ = flash_attn_cuda.fwd_kvcache(
|
||||
q, k_cache, v_cache, k, v,
|
||||
cache_seqlens_a,
|
||||
@@ -692,7 +702,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
None,
|
||||
self.scaling,
|
||||
True,
|
||||
-1, -1,
|
||||
window_size, window_size,
|
||||
True,
|
||||
0,
|
||||
)
|
||||
@@ -733,6 +743,10 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
|
||||
v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
|
||||
|
||||
if self.sliding_window and k_states.shape[2] >= self.sliding_window:
|
||||
k_states = k_states[:, :, -self.sliding_window:, :]
|
||||
v_states = v_states[:, :, -self.sliding_window:, :]
|
||||
|
||||
attn_mask_lr = causal_lower_right(q_len, k_states.shape[2])
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q_states,
|
||||
@@ -750,7 +764,13 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
attn_weights = torch.matmul(q_states, k_states)
|
||||
|
||||
attn_weights *= self.scaling
|
||||
attn_mask = attn_params.get_attn_mask(attn_weights.device)
|
||||
|
||||
if attn_mask is not None: attn_weights = attn_weights + attn_mask
|
||||
if self.sliding_window and k_states.shape[-1] >= self.sliding_window:
|
||||
attn_weights = attn_weights[:, :, :, -self.sliding_window:]
|
||||
v_states = v_states[:, :, -self.sliding_window:, :]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
|
||||
|
||||
v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
|
||||
@@ -763,13 +783,20 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
|
||||
|
||||
assert has_flash_attn_with_window or not self.sliding_window, \
|
||||
"Installed version of flash-attn does not support sliding window"
|
||||
|
||||
flash_kwargs = {
|
||||
"window_size": (self.sliding_window, self.sliding_window)
|
||||
} if self.sliding_window else {}
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
q_states,
|
||||
k_states,
|
||||
v_states,
|
||||
causal = True
|
||||
causal = True,
|
||||
softmax_scale = self.scaling,
|
||||
**flash_kwargs
|
||||
)
|
||||
attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
|
||||
return attn_output
|
||||
@@ -777,6 +804,9 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
|
||||
|
||||
assert not self.sliding_window, \
|
||||
"Sliding window not currently supported for xformers"
|
||||
|
||||
    # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than sm_80
|
||||
# xformer does not expand the kv automatically, we need to do it manually. The efficiency between
|
||||
    # xformers.memory_efficient_attention and flash_attn in >sm_80 are almost the same. But the matrix operation
|
||||
|
||||
@@ -102,6 +102,7 @@ class ExLlamaV2Config:
|
||||
use_qk_norm: bool
|
||||
query_pre_attn_scalar: float | None
|
||||
final_logit_softcapping: float | None
|
||||
sliding_window: int
|
||||
|
||||
checkpoint_fused_mlp: bool
|
||||
|
||||
@@ -259,6 +260,8 @@ class ExLlamaV2Config:
|
||||
"n_positions"], 2048)
|
||||
self.original_max_seq_len = self.max_seq_len
|
||||
|
||||
self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0)
|
||||
|
||||
rs = read(read_config, dict, "rope_scaling", None)
|
||||
if rs:
|
||||
scaling_type = rs.get("type", None)
|
||||
|
||||
@@ -227,7 +227,13 @@ class ExLlamaV2:
|
||||
pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
|
||||
self.modules += [pd]
|
||||
else:
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx)
|
||||
if self.config.arch.alternating_swa:
|
||||
swa = self.config.sliding_window if not bool(layer_idx % 2) else 0
|
||||
elif self.config.arch.swa:
|
||||
swa = self.config.sliding_window
|
||||
else:
|
||||
swa = 0
|
||||
attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
|
||||
if self.config.arch.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
|
||||
else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
|
||||
self.modules += [attn, mlp]
|
||||
|
||||
Reference in New Issue
Block a user