Allow per-layer RoPE theta

This commit is contained in:
turboderp
2025-03-14 23:12:50 +01:00
parent 23395dfa42
commit 7b05acd233
6 changed files with 66 additions and 46 deletions

View File

@@ -135,7 +135,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
has_norm: bool = True,
has_residual: bool = True,
sliding_window: int = 0,
archparams = None
archparams = None,
rope_index: int = 0
):
super().__init__(model, key, archparams)
@@ -149,6 +150,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.layer_idx = layer_idx
self.has_norm = has_norm
self.has_residual = has_residual
self.rope_index = rope_index
self.q_handle = None
self.temp_lora_size = 0
@@ -503,7 +505,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
sc = attn_params.get_alt_rope_embed(self.device_idx)
if not sc:
sin, cos = constants.sin, constants.cos
sin, cos = constants.sin[self.rope_index], constants.cos[self.rope_index]
else:
sin, cos = sc
@@ -769,8 +771,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]:
ext_c.rope_(
t,
context.sin,
context.cos,
context.sin[self.rope_index],
context.cos[self.rope_index],
0,
(b - a) * heads,
self.head_dim,
@@ -1116,7 +1118,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
sc = attn_params.get_alt_rope_embed(self.device_idx)
if not sc:
sin, cos = constants.sin, constants.cos
sin, cos = constants.sin[self.rope_index], constants.cos[self.rope_index]
else:
sin, cos = sc
@@ -1304,8 +1306,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]:
ext_c.rope_(
t,
context.sin,
context.cos,
context.sin[self.rope_index],
context.cos[self.rope_index],
past_len,
(b - a) * heads,
self.head_dim,
@@ -1444,7 +1446,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
if self.archparams.rope_style != RopeStyle.NONE:
alt_cs = kwargs.get("alt_rope_embedding")
cos, sin = alt_cs if alt_cs else (constants.cos, constants.sin)
cos, sin = alt_cs if alt_cs else (constants.cos[self.rope_index], constants.sin[self.rope_index])
sc = attn_params.get_alt_rope_embed(self.device_idx)
if sc: sin, cos = sc

View File

@@ -40,8 +40,8 @@ class ExLlamaV2DeviceContext:
scratch_bytes: int
scratch_idx: int
sin: torch.Tensor | None
cos: torch.Tensor | None
sin: list[torch.Tensor] | None
cos: list[torch.Tensor] | None
scratch: torch.Tensor | None
@@ -116,35 +116,49 @@ class ExLlamaV2DeviceContext:
def prepare_sincos(self):
device = _torch_device(self.device_idx)
cfg = self.model.config
if self.archparams.rope_style == RopeStyle.NONE:
self.sin = torch.zeros((1,), device = device, dtype = torch.half)
self.cos = self.sin
return
# RoPE params
thetas = [cfg.rotary_embedding_base]
if cfg.rotary_embedding_base_alt:
thetas.append(cfg.rotary_embedding_base_alt)
inv_freq, scaling_factor = rope.get_rope_params(device, cfg)
self.sin = []
self.cos = []
# Common
for theta in thetas:
scale = cfg.scale_pos_emb or 1.0
t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
if scale != 1.0: t /= scale
if self.archparams.rope_style == RopeStyle.NONE:
sin = torch.zeros((1,), device = device, dtype = torch.half)
cos = sin
self.sin.append(sin)
self.cos.append(cos)
break
freqs = torch.einsum("i,j->ij", t, inv_freq)
if self.archparams.rope_style == RopeStyle.NEOX:
emb = torch.cat((freqs, freqs), dim=-1)
elif self.archparams.rope_style == RopeStyle.GPTJ:
emb = torch.repeat_interleave(freqs, 2, dim=-1)
else:
raise ValueError()
# RoPE params
self.sin = emb.sin()[None, None, :, :]
self.cos = emb.cos()[None, None, :, :]
if scaling_factor != 1.0:
self.sin *= scaling_factor
self.cos *= scaling_factor
self.sin = self.sin.half()
self.cos = self.cos.half()
inv_freq, scaling_factor = rope.get_rope_params(device, cfg, theta)
# Common
scale = cfg.scale_pos_emb or 1.0
t = torch.arange(cfg.max_seq_len, device = device, dtype = torch.float32)
if scale != 1.0: t /= scale
freqs = torch.einsum("i,j->ij", t, inv_freq)
if self.archparams.rope_style == RopeStyle.NEOX:
emb = torch.cat((freqs, freqs), dim=-1)
elif self.archparams.rope_style == RopeStyle.GPTJ:
emb = torch.repeat_interleave(freqs, 2, dim=-1)
else:
raise ValueError()
sin = emb.sin()[None, None, :, :]
cos = emb.cos()[None, None, :, :]
if scaling_factor != 1.0:
sin *= scaling_factor
cos *= scaling_factor
sin = sin.half()
cos = cos.half()
self.sin.append(sin)
self.cos.append(cos)

View File

@@ -106,6 +106,7 @@ class ExLlamaV2:
for layer_idx in range(cfg.num_hidden_layers):
layer_key = cfg.arch.lm_prefix + f"model.layers.{layer_idx}"
rope_index = 0
if cfg.arch.lm.alternating_swa:
swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0

View File

@@ -52,7 +52,7 @@ def gen_mrope_embed(
# Get RoPE params
inv_freq, scaling_factor = rope.get_rope_params("cpu", config)
inv_freq, scaling_factor = rope.get_rope_params("cpu", config, config.rotary_embedding_base)
# Create embeddings

View File

@@ -30,7 +30,8 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
key: str,
layer_idx: int,
sliding_window: int = 0,
archparams = None
archparams = None,
rope_index: int = 0,
):
super().__init__(model, key, archparams)
@@ -49,7 +50,8 @@ class ExLlamaV2ParallelDecoder(ExLlamaV2Module):
layer_idx,
has_norm = False,
has_residual = False,
sliding_window = sliding_window
sliding_window = sliding_window,
rope_index = rope_index
)
self.mlp = ExLlamaV2MLP(
model,

View File

@@ -12,9 +12,9 @@ if TYPE_CHECKING:
def get_rope_params_su(
device: torch.Device,
cfg: ExLlamaV2Config,
base: float
):
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
base = cfg.rotary_embedding_base
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
@@ -35,9 +35,9 @@ def get_rope_params_su(
def get_rope_params_llama3(
device: torch.Device,
cfg: ExLlamaV2Config,
base: float
):
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
base = cfg.rotary_embedding_base
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
@@ -80,10 +80,10 @@ def get_rope_params_llama3(
def get_rope_params_yarn(
device: torch.Device,
cfg: ExLlamaV2Config,
base: float,
):
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
base = cfg.rotary_embedding_base
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
base *= cfg.scale_alpha_value ** (cfg.head_dim / (cfg.head_dim - 2))
@@ -148,10 +148,10 @@ def get_rope_params_yarn(
def get_rope_params_default(
device: torch.Device,
cfg: ExLlamaV2Config,
base: float,
):
head_dim = int(cfg.head_dim * cfg.partial_rotary_factor)
base = cfg.rotary_embedding_base
if cfg.scale_alpha_value and cfg.scale_alpha_value != 1.0:
base *= cfg.scale_alpha_value ** (head_dim / (head_dim - 2))
@@ -162,15 +162,16 @@ def get_rope_params_default(
def get_rope_params(
device: torch.Device,
cfg: ExLlamaV2Config,
base: float,
):
if cfg.alt_rope_method == "su":
inv_freq, scaling_factor = get_rope_params_su(device, cfg)
inv_freq, scaling_factor = get_rope_params_su(device, cfg, base)
elif cfg.alt_rope_method == "llama3":
inv_freq, scaling_factor = get_rope_params_llama3(device, cfg)
inv_freq, scaling_factor = get_rope_params_llama3(device, cfg, base)
elif cfg.alt_rope_method == "yarn":
inv_freq, scaling_factor = get_rope_params_yarn(device, cfg)
inv_freq, scaling_factor = get_rope_params_yarn(device, cfg, base)
else:
inv_freq, scaling_factor = get_rope_params_default(device, cfg)
inv_freq, scaling_factor = get_rope_params_default(device, cfg, base)
if cfg.arch.lm.rope_freq_half:
inv_freq = inv_freq.half()