# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from torch import nn

from fairseq.modules.quant_noise import quant_noise


class AdaptiveInput(nn.Module):
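    """Adaptive input embedding (cf. Baevski & Auli, 2018, "Adaptive Input
    Representations for Neural Language Modeling"): the vocabulary is split
    into bands at the ``cutoff`` indices, later (rarer) bands use smaller
    embedding dimensions, and every band is projected to ``output_dim``.
    """
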
    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        initial_dim: int,
        factor: float,
        output_dim: int,
        cutoff: List[int],
        q_noise: float = 0,
        qn_block_size: int = 8,
    ):
        super().__init__()

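        # If the last cutoff does not already cover the full vocabulary,
        # extend it so every token id falls into some band.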
        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        self.cutoff = cutoff
        self.embedding_dim = output_dim
        self.padding_idx = padding_idx

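        # One embedding table per band: band i covers ids cutoff[i-1]..cutoff[i],
        # embedded at initial_dim / factor**i and projected (optionally through
        # quantization noise) to output_dim; only the first band keeps padding_idx.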
        self.embeddings = nn.ModuleList()
        for i in range(len(self.cutoff)):
            prev = self.cutoff[i - 1] if i > 0 else 0
            size = self.cutoff[i] - prev
            dim = int(initial_dim // (factor**i))
            seq = nn.Sequential(
                nn.Embedding(size, dim, self.padding_idx),
                quant_noise(
                    nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size
                ),
            )

            self.embeddings.append(seq)
            self.padding_idx = None
        self.padding_idx = padding_idx

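        # Scaled-normal init for embedding tables (padding row zeroed),
        # Xavier-uniform for the projection layers.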
        def init_weights(m):
            if isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
                nn.init.constant_(m.weight[padding_idx], 0)
            elif hasattr(m, "weight"):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def weights_for_band(self, band: int):
        return self.embeddings[band][0].weight, self.embeddings[band][1].weight

    def forward(self, input: torch.Tensor):
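        # Allocate the output from the float buffer (same dtype/device as the
        # module), then fill it band by band: mask the ids in band i, shift
        # them to the band-local index range, and embed with that band's table.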
        result = self._float_tensor.new(input.shape + (self.embedding_dim,))
        for i in range(len(self.cutoff)):
            mask = input.lt(self.cutoff[i])
            if i > 0:
                mask.mul_(input.ge(self.cutoff[i - 1]))
                chunk_input = input[mask] - self.cutoff[i - 1]
            else:
                chunk_input = input[mask]
            if mask.any():
                result[mask] = self.embeddings[i](chunk_input)
        return result
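

# A minimal usage sketch (not part of the original file). The vocabulary size,
# cutoffs, dims, and batch shape below are made-up example values, and the
# sketch assumes fairseq is installed so the quant_noise import resolves.
if __name__ == "__main__":
    emb = AdaptiveInput(
        vocab_size=10000,      # hypothetical vocabulary size
        padding_idx=0,
        initial_dim=512,       # embedding dim for the most frequent band
        factor=4.0,            # each later band uses dim / 4
        output_dim=512,
        cutoff=[2000, 8000],   # bands: [0, 2000), [2000, 8000), [8000, 10000)
    )
    tokens = torch.randint(0, 10000, (2, 16))  # (batch, seq_len) of token ids
    out = emb(tokens)
    print(out.shape)  # torch.Size([2, 16, 512])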