Allow component models to use learned positional embeddings without regard to the LLM's max_seq_len
@@ -20,11 +20,13 @@ class ExLlamaV2PosEmbedding(ExLlamaV2Module):
     def __init__(
         self,
         model: ExLlamaV2,
-        key: str
+        key: str,
+        override_max_seq_len: bool = False
     ):
         super().__init__(model, key)

         self.native_ctx_size = model.config.max_seq_len
+        self.override_max_seq_len = override_max_seq_len
         self.embedding = None


@@ -34,8 +36,9 @@ class ExLlamaV2PosEmbedding(ExLlamaV2Module):
         w = self.load_weight()
         assert isinstance(w, nn.Parameter)
         self.native_ctx_size = w.shape[0]
-        assert self.model.config.max_seq_len <= self.native_ctx_size, \
-            f"Learned positional embeddings cannot be extended past native size of {self.native_ctx_size}."
+        if not self.override_max_seq_len:
+            assert self.model.config.max_seq_len <= self.native_ctx_size, \
+                f"Learned positional embeddings cannot be extended past native size of {self.native_ctx_size}."

         self.embedding = nn.Embedding(self.native_ctx_size, self.model.config.hidden_size, device = "meta")
         self.embedding.weight = w
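The flag matters when a smaller component model (for example a vision tower in a multimodal stack) carries its own learned position table that is shorter than the main LLM's configured max_seq_len. Below is a minimal usage sketch, not part of this commit: the model directory, the module key, and the import path for ExLlamaV2PosEmbedding are assumptions for illustration only, and pos_emb.load() stands in for however the component normally loads its modules.

    # Sketch: load a component's learned positional embedding whose table is
    # shorter than the LLM's max_seq_len, using the new override flag.
    from exllamav2 import ExLlamaV2, ExLlamaV2Config
    from exllamav2.pos_embedding import ExLlamaV2PosEmbedding  # assumed module path

    config = ExLlamaV2Config()
    config.model_dir = "/path/to/component-model"  # placeholder path
    config.prepare()
    config.max_seq_len = 32768  # LLM context, longer than the component's learned table
    model = ExLlamaV2(config)

    # With override_max_seq_len = True, loading no longer asserts that
    # config.max_seq_len fits inside the weight's native_ctx_size.
    pos_emb = ExLlamaV2PosEmbedding(
        model,
        "vision_model.embeddings.position_embedding",  # hypothetical key
        override_max_seq_len = True
    )
    pos_emb.load()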