Allow component models to use learned pos embeddings without regard to LLM max_seq_len

turboderp
2025-03-14 23:13:59 +01:00
parent 7b05acd233
commit 9669fa33c9


@@ -20,11 +20,13 @@ class ExLlamaV2PosEmbedding(ExLlamaV2Module):
     def __init__(
         self,
         model: ExLlamaV2,
-        key: str
+        key: str,
+        override_max_seq_len: bool = False
     ):
         super().__init__(model, key)
         self.native_ctx_size = model.config.max_seq_len
+        self.override_max_seq_len = override_max_seq_len
         self.embedding = None
@@ -34,8 +36,9 @@ class ExLlamaV2PosEmbedding(ExLlamaV2Module):
         w = self.load_weight()
         assert isinstance(w, nn.Parameter)
         self.native_ctx_size = w.shape[0]
-        assert self.model.config.max_seq_len <= self.native_ctx_size, \
-            f"Learned positional embeddings cannot be extended past native size of {self.native_ctx_size}."
+        if not self.override_max_seq_len:
+            assert self.model.config.max_seq_len <= self.native_ctx_size, \
+                f"Learned positional embeddings cannot be extended past native size of {self.native_ctx_size}."
         self.embedding = nn.Embedding(self.native_ctx_size, self.model.config.hidden_size, device = "meta")
         self.embedding.weight = w
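
In effect, a component model whose learned position table is shorter than the host LLM's configured context can now load it by passing override_max_seq_len = True, which skips the size assertion above. A minimal sketch of the intended use, assuming a standard exllamav2 setup; the import path, the weight key "model.embed_positions", and the direct instantiation are illustrative only, since this module is normally constructed internally when the model is built:

    from exllamav2 import ExLlamaV2, ExLlamaV2Config
    from exllamav2.pos_embedding import ExLlamaV2PosEmbedding  # assumed import path

    config = ExLlamaV2Config("/path/to/model")   # placeholder model directory
    config.max_seq_len = 8192                    # may exceed the embedding's native size

    model = ExLlamaV2(config)

    # With the flag set, load() no longer asserts
    # max_seq_len <= native_ctx_size: the module keeps its native-sized
    # embedding table while the surrounding LLM runs at a longer context.
    pos_emb = ExLlamaV2PosEmbedding(model, "model.embed_positions",
                                    override_max_seq_len = True)
    pos_emb.load()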