mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Allow per-layer RoPE theta
This commit is contained in:
@@ -135,7 +135,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
has_norm: bool = True,
|
||||
has_residual: bool = True,
|
||||
sliding_window: int = 0,
|
||||
archparams = None
|
||||
archparams = None,
|
||||
rope_index: int = 0
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
@@ -149,6 +150,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
self.layer_idx = layer_idx
|
||||
self.has_norm = has_norm
|
||||
self.has_residual = has_residual
|
||||
self.rope_index = rope_index
|
||||
|
||||
self.q_handle = None
|
||||
self.temp_lora_size = 0
|
||||
@@ -503,7 +505,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
sc = attn_params.get_alt_rope_embed(self.device_idx)
|
||||
if not sc:
|
||||
sin, cos = constants.sin, constants.cos
|
||||
sin, cos = constants.sin[self.rope_index], constants.cos[self.rope_index]
|
||||
else:
|
||||
sin, cos = sc
|
||||
|
||||
@@ -769,8 +771,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]:
|
||||
ext_c.rope_(
|
||||
t,
|
||||
context.sin,
|
||||
context.cos,
|
||||
context.sin[self.rope_index],
|
||||
context.cos[self.rope_index],
|
||||
0,
|
||||
(b - a) * heads,
|
||||
self.head_dim,
|
||||
@@ -1116,7 +1118,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
sc = attn_params.get_alt_rope_embed(self.device_idx)
|
||||
if not sc:
|
||||
sin, cos = constants.sin, constants.cos
|
||||
sin, cos = constants.sin[self.rope_index], constants.cos[self.rope_index]
|
||||
else:
|
||||
sin, cos = sc
|
||||
|
||||
@@ -1304,8 +1306,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]:
|
||||
ext_c.rope_(
|
||||
t,
|
||||
context.sin,
|
||||
context.cos,
|
||||
context.sin[self.rope_index],
|
||||
context.cos[self.rope_index],
|
||||
past_len,
|
||||
(b - a) * heads,
|
||||
self.head_dim,
|
||||
@@ -1444,7 +1446,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
|
||||
if self.archparams.rope_style != RopeStyle.NONE:
|
||||
alt_cs = kwargs.get("alt_rope_embedding")
|
||||
cos, sin = alt_cs if alt_cs else (constants.cos, constants.sin)
|
||||
cos, sin = alt_cs if alt_cs else (constants.cos[self.rope_index], constants.sin[self.rope_index])
|
||||
|
||||
sc = attn_params.get_alt_rope_embed(self.device_idx)
|
||||
if sc: sin, cos = sc
|
||||
|
||||
@@ -40,8 +40,8 @@ class ExLlamaV2DeviceContext:
|
||||
scratch_bytes: int
|
||||
scratch_idx: int
|
||||
|
||||
sin: torch.Tensor | None
|
||||
cos: torch.Tensor | None
|
||||
sin: list[torch.Tensor] | None
|
||||
cos: list[torch.Tensor] | None
|
||||
|
||||
scratch: torch.Tensor | None
|
||||
|
||||
@@ -116,35 +116,49 @@ class ExLlamaV2DeviceContext:
|
||||
def prepare_sincos(self):
|
||||
|
||||
device = _torch_device(self.device_idx)
|
||||
|
||||
cfg = self.model.config
|
||||
if self.archparams.rope_style == RopeStyle.NONE:
|
||||
self.sin = torch.zeros((1,), device = device, dtype = torch.half)
|
||||
self.cos = self.sin
|
||||
return
|
||||
|
||||
# RoPE params
|
||||
thetas = [cfg.rotary_embedding_base]
|
||||
if cfg.rotary_embedding_base_alt:
|
||||
thetas.append(cfg.rotary_embedding_base_alt)
|
||||
|
||||
inv_freq, scaling_factor = rope.get_rope_params(device, cfg)
|
||||
self.sin = []
|
||||
self.cos = []
|
||||
|
||||
# Common
|
||||
for theta in thetas:
|
||||
|
||||
scale = cfg.scale_pos_emb or 1.0
|
||||
t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
|
||||
if scale != 1.0: t /= scale
|
||||
if self.archparams.rope_style == RopeStyle.NONE:
|
||||
sin = torch.zeros((1,), device = device, dtype = torch.half)
|
||||
cos = sin
|
||||
self.sin.append(sin)
|
||||
self.cos.append(cos)
|
||||
break
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
if self.archparams.rope_style == RopeStyle.NEOX:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
elif self.archparams.rope_style == RopeStyle.GPTJ:
|
||||
emb = torch.repeat_interleave(freqs, 2, dim=-1)
|
||||
else:
|
||||
raise ValueError()
|
||||
# RoPE params
|
||||
|
||||
self.sin = emb.sin()[None, None, :, :]
|
||||
self.cos = emb.cos()[None, None, :, :]
|
||||
if scaling_factor != 1.0:
|
||||
self.sin *= scaling_factor
|
||||
self.cos *= scaling_factor
|
||||
self.sin = self.sin.half()
|
||||
self.cos = self.cos.half()
|
||||
inv_freq, scaling_factor = rope.get_rope_params(device, cfg, theta)
|
||||
|
||||
# Common
|
||||
|
||||
scale = cfg.scale_pos_emb or 1.0
|
||||
t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
|
||||
if scale != 1.0: t /= scale
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
if self.archparams.rope_style == RopeStyle.NEOX:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
elif self.archparams.rope_style == RopeStyle.GPTJ:
|
||||
emb = torch.repeat_interleave(freqs, 2, dim=-1)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
sin = emb.sin()[None, None, :, :]
|
||||
cos = emb.cos()[None, None, :, :]
|
||||
if scaling_factor != 1.0:
|
||||
sin *= scaling_factor
|
||||
cos *= scaling_factor
|
||||
sin = sin.half()
|
||||
cos = cos.half()
|
||||
|
||||
self.sin.append(sin)
|
||||
self.cos.append(cos)
|
||||
@@ -106,6 +106,7 @@ class ExLlamaV2:
|
||||
for layer_idx in range(cfg.num_hidden_layers):
|
||||
|
||||
layer_key = cfg.arch.lm_prefix + f"model.layers.{layer_idx}"
|
||||
rope_index = 0
|
||||
|
||||
if cfg.arch.lm.alternating_swa:
|
||||
swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
|
||||
|
||||
@@ -52,7 +52,7 @@ def gen_mrope_embed(
|
||||
|
||||
# Get RoPE params
|
||||
|
||||
inv_freq, scaling_factor = rope.get_rope_params("cpu", config)
|
||||
inv_freq, scaling_factor = rope.get_rope_params("cpu", config, config.rotary_embedding_base)
|
||||
|
||||
# Create embeddings
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
|
||||
key: str,
|
||||
layer_idx: int,
|
||||
sliding_window: int = 0,
|
||||
archparams = None
|
||||
archparams = None,
|
||||
rope_index: int = 0,
|
||||
):
|
||||
super().__init__(model, key, archparams)
|
||||
|
||||
@@ -49,7 +50,8 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
|
||||
layer_idx,
|
||||
has_norm = False,
|
||||
has_residual = False,
|
||||
sliding_window = sliding_window
|
||||
sliding_window = sliding_window,
|
||||
rope_index = rope_index
|
||||
)
|
||||
self.mlp = ExLlamaV2MLP(
|
||||
model,
|
||||
|
||||
@@ -12,9 +12,9 @@ if TYPE_CHECKING:
|
||||
def get_rope_params_su(
|
||||
device: torch.Device,
|
||||
cfg: ExLlamaV2Config,
|
||||
base: float
|
||||
):
|
||||
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
|
||||
base = cfg.rotary_embedding_base
|
||||
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
|
||||
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
|
||||
|
||||
@@ -35,9 +35,9 @@ def get_rope_params_su(
|
||||
def get_rope_params_llama3(
|
||||
device: torch.Device,
|
||||
cfg: ExLlamaV2Config,
|
||||
base: float
|
||||
):
|
||||
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
|
||||
base = cfg.rotary_embedding_base
|
||||
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
|
||||
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
|
||||
|
||||
@@ -80,10 +80,10 @@ def get_rope_params_llama3(
|
||||
def get_rope_params_yarn(
|
||||
device: torch.Device,
|
||||
cfg: ExLlamaV2Config,
|
||||
base: float,
|
||||
):
|
||||
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
|
||||
|
||||
base = cfg.rotary_embedding_base
|
||||
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
|
||||
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
|
||||
|
||||
@@ -148,10 +148,10 @@ def get_rope_params_yarn(
|
||||
def get_rope_params_default(
|
||||
device: torch.Device,
|
||||
cfg: ExLlamaV2Config,
|
||||
base: float,
|
||||
):
|
||||
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
|
||||
|
||||
base = cfg.rotary_embedding_base
|
||||
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
|
||||
base *= cfg.scale_alpha_value ** (head_dim / (head_dim - 2))
|
||||
|
||||
@@ -162,15 +162,16 @@ def get_rope_params_default(
|
||||
def get_rope_params(
|
||||
device: torch.Device,
|
||||
cfg: ExLlamaV2Config,
|
||||
base: float,
|
||||
):
|
||||
if cfg.alt_rope_method == "su":
|
||||
inv_freq, scaling_factor = get_rope_params_su(device, cfg)
|
||||
inv_freq, scaling_factor = get_rope_params_su(device, cfg, base)
|
||||
elif cfg.alt_rope_method == "llama3":
|
||||
inv_freq, scaling_factor = get_rope_params_llama3(device, cfg)
|
||||
inv_freq, scaling_factor = get_rope_params_llama3(device, cfg, base)
|
||||
elif cfg.alt_rope_method == "yarn":
|
||||
inv_freq, scaling_factor = get_rope_params_yarn(device, cfg)
|
||||
inv_freq, scaling_factor = get_rope_params_yarn(device, cfg, base)
|
||||
else:
|
||||
inv_freq, scaling_factor = get_rope_params_default(device, cfg)
|
||||
inv_freq, scaling_factor = get_rope_params_default(device, cfg, base)
|
||||
|
||||
if cfg.arch.lm.rope_freq_half:
|
||||
inv_freq = inv_freq.half()
|
||||
|
||||
Reference in New Issue
Block a user