# Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
# synced 2026-03-11 06:20:12 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from fairseq import utils
|
|
from fairseq.incremental_decoding_utils import (
|
|
FairseqIncrementalState,
|
|
with_incremental_state,
|
|
)
|
|
from fairseq.modules.fairseq_dropout import FairseqDropout
|
|
from torch import Tensor
|
|
|
|
from .unfold import unfold1d
|
|
|
|
|
|
def DynamicConv(
    input_size,
    kernel_size=1,
    padding_l=None,
    num_heads=1,
    weight_dropout=0.0,
    weight_softmax=False,
    renorm_padding=False,
    bias=False,
    conv_bias=False,
    query_size=None,
    in_proj=False,
):
    """Factory for a dynamic convolution module.

    Prefers the fused CUDA kernel (``DynamicconvLayer``) when CUDA is
    available and the extension is importable; otherwise falls back to the
    pure-PyTorch ``DynamicConv1dTBC`` implementation.
    """
    # NOTE(review): in_proj is accepted but never forwarded to either
    # implementation — confirm against callers whether this is intentional.
    shared_kwargs = dict(
        kernel_size=kernel_size,
        padding_l=padding_l,
        num_heads=num_heads,
        weight_dropout=weight_dropout,
        weight_softmax=weight_softmax,
        renorm_padding=renorm_padding,
        bias=bias,
        conv_bias=conv_bias,
        query_size=query_size,
    )
    if torch.cuda.is_available():
        try:
            from fairseq.modules.dynamicconv_layer import DynamicconvLayer

            return DynamicconvLayer(input_size, **shared_kwargs)
        except ImportError as e:
            # The CUDA extension is optional; report the failure and fall
            # through to the pure-PyTorch implementation.
            print(e)
    return DynamicConv1dTBC(input_size, **shared_kwargs)
|
|
|
|
|
|
def Linear(in_features, out_features, bias=True):
    """Return an ``nn.Linear`` with Xavier-uniform weights and a zeroed bias."""
    module = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(module.weight)
    if module.bias is not None:
        nn.init.constant_(module.bias, 0.0)
    return module
|
|
|
|
|
|
@with_incremental_state
class DynamicConv1dTBC(nn.Module):
    """Dynamic lightweight convolution taking T x B x C inputs

    Args:
        input_size: # of channels of the input
        kernel_size: convolution channels
        padding_l: padding to the left when using "same" padding
        num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
        weight_dropout: the drop rate of the DropConnect to drop the weight
        weight_softmax: normalize the weight with softmax before the convolution
        renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
        bias: use bias
        conv_bias: bias of the convolution
        query_size: specified when feeding a different input as the query
        in_proj: project the input and generate the filter together

    Shape:
        Input: TxBxC, i.e. (timesteps, batch_size, input_size)
        Output: TxBxC, i.e. (timesteps, batch_size, input_size)

    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    """

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        num_heads=1,
        weight_dropout=0.0,
        weight_softmax=False,
        renorm_padding=False,
        bias=False,
        conv_bias=False,
        query_size=None,
        in_proj=False,
    ):
        super().__init__()
        self.input_size = input_size
        # When no separate query size is supplied, filters are predicted from x itself.
        self.query_size = input_size if query_size is None else query_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
        self.weight_softmax = weight_softmax
        self.renorm_padding = renorm_padding

        if in_proj:
            # One projection produces both the convolved input (first input_size
            # channels) and the per-position filters (last num_heads * kernel_size).
            self.weight_linear = Linear(
                self.input_size, self.input_size + num_heads * kernel_size * 1
            )
        else:
            # Filters only; the input tensor is convolved as-is.
            self.weight_linear = Linear(
                self.query_size, num_heads * kernel_size * 1, bias=bias
            )
        if conv_bias:
            self.conv_bias = nn.Parameter(torch.Tensor(input_size))
        else:
            self.conv_bias = None
        self.reset_parameters()

    @property
    def in_proj(self):
        # Recovered from the projection width rather than stored as a flag, so
        # it always stays consistent with the actual weight_linear layout.
        return (
            self.weight_linear.out_features
            == self.input_size + self.num_heads * self.kernel_size
        )

    def reset_parameters(self):
        """Re-initialize the filter projection and zero the convolution bias."""
        self.weight_linear.reset_parameters()
        if self.conv_bias is not None:
            nn.init.constant_(self.conv_bias, 0.0)

    def forward(self, x, incremental_state=None, query=None, unfold=None):
        """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
        args:
            x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
            incremental_state: A dict to keep the state
            unfold: unfold the input or not. If not, we use the matrix trick instead
            query: use the specified query to predict the conv filters
        """
        unfold = (
            x.size(0) > 512 if unfold is None else unfold
        )  # use unfold mode as default for long sequence to save memory
        # Incremental decoding keeps a per-step buffer, so it must use unfold.
        unfold = unfold or (incremental_state is not None)
        # An external query is incompatible with in_proj (filters come from x there).
        assert query is None or not self.in_proj

        if query is None:
            query = x
        if unfold:
            output = self._forward_unfolded(x, incremental_state, query)
        else:
            output = self._forward_expanded(x, incremental_state, query)

        if self.conv_bias is not None:
            output = output + self.conv_bias.view(1, 1, -1)
        return output

    def _forward_unfolded(self, x, incremental_state, query):
        """The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right."""
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size

        if self.in_proj:
            proj = self.weight_linear(x)
            # First input_size channels are the (projected) input ...
            x = proj.narrow(2, 0, self.input_size).contiguous()
            # ... the remaining H*K channels are the per-position filters.
            weight = (
                proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
            )
        else:
            weight = self.weight_linear(query).view(T * B * H, -1)

        # renorm_padding is only implemented in _forward_expanded
        assert not self.renorm_padding or incremental_state is not None

        if incremental_state is not None:
            # Step-wise decoding: append the current input to the buffer of the
            # previous kernel_size - 1 steps.
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(T * B * H, R, -1)
        else:
            padding_l = self.padding_l
            # Shrink the kernel when it is longer than the sequence (only valid
            # for the fully left-padded case).
            if K > T and padding_l == K - 1:
                weight = weight.narrow(1, K - T, T)
                K, padding_l = T, T - 1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0)
            x_unfold = x_unfold.view(T * B * H, R, K)

        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)
        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            # Only keep the taps that have matching buffered inputs.
            weight = weight[:, -x_unfold.size(2) :]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = self.weight_dropout_module(weight, inplace=False)

        # Batched dot product of each window with its filter.
        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output

    def _forward_expanded(self, x, incremental_stat, query):
        """Turn the convolution filters into band matrices and do matrix multiplication.
        This is faster when the sequence is short, but less memory efficient.
        This is not used in the decoder during inference.

        NOTE: ``incremental_stat`` is accepted for signature parity with
        ``_forward_unfolded`` but is never read here.
        """
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size
        if self.in_proj:
            proj = self.weight_linear(x)
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = (
                proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1)
            )
        else:
            weight = self.weight_linear(query).view(T * B * H, -1)

        if not self.renorm_padding:
            if self.weight_softmax:
                weight = F.softmax(weight, dim=1)
            weight = self.weight_dropout_module(weight, inplace=False)
        weight = weight.narrow(1, 0, K).contiguous()
        weight = weight.view(T, B * H, K).transpose(0, 1)

        x = x.view(T, B * H, R).transpose(0, 1)
        if self.weight_softmax and self.renorm_padding:
            # turn the convolution filters into band matrices
            weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf"))
            # Writing through this strided view lays row t's K taps down
            # starting at column t of the padded band.
            weight_expanded.as_strided(
                (B * H, T, K), (T * (T + K - 1), T + K, 1)
            ).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, self.padding_l, T)
            # normalize the weight over valid positions like self-attention
            # (the -inf padding vanishes under softmax)
            weight_expanded = F.softmax(weight_expanded, dim=2)
            weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False)
        else:
            P = self.padding_l
            # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length
            if K > T and P == K - 1:
                weight = weight.narrow(2, K - T, T)
                K, P = T, T - 1
            # turn the convolution filters into band matrices
            weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
            weight_expanded.as_strided(
                (B * H, T, K), (T * (T + K - 1), T + K, 1)
            ).copy_(weight)
            weight_expanded = weight_expanded.narrow(2, P, T)  # B*H x T x T
        output = torch.bmm(weight_expanded, x)
        output = output.transpose(0, 1).contiguous().view(T, B, C)
        return output

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the buffered inputs along the batch dim (dim 1), e.g. for beam search."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        # Delegates buffer storage to fairseq's incremental-state utilities.
        return utils.get_incremental_state(self, incremental_state, "input_buffer")

    def _set_input_buffer(self, incremental_state, new_buffer):
        return utils.set_incremental_state(
            self, incremental_state, "input_buffer", new_buffer
        )

    def extra_repr(self):
        """Summarize the module configuration for ``repr``/``print``."""
        s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format(
            self.input_size,
            self.kernel_size,
            self.padding_l,
            self.num_heads,
            self.weight_softmax,
            self.conv_bias is not None,
            self.renorm_padding,
            self.in_proj,
        )

        if self.query_size != self.input_size:
            s += ", query_size={}".format(self.query_size)
        if self.weight_dropout_module.p > 0.0:
            s += ", weight_dropout={}".format(self.weight_dropout_module.p)
        return s
|
|
|
|
|
|
class DynamicConv_scripatable(nn.Module, FairseqIncrementalState):
    """Dynamic lightweight convolution taking T x B x C inputs.

    TorchScript-compatible variant: incremental state is handled via
    ``FairseqIncrementalState`` and only the unfolded forward path is provided.

    Args:
        input_size: # of channels of the input
        kernel_size: convolution channels
        padding_l: padding to the left when using "same" padding
        num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
        weight_dropout: the drop rate of the DropConnect to drop the weight
        weight_softmax: normalize the weight with softmax before the convolution
        renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1)
        bias: use bias
        conv_bias: bias of the convolution
        query_size: specified when feeding a different input as the query
        in_proj: project the input and generate the filter together

    Shape:
        Input: TxBxC, i.e. (timesteps, batch_size, input_size)
        Output: TxBxC, i.e. (timesteps, batch_size, input_size)

    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    """

    def __init__(
        self,
        input_size,
        kernel_size=1,
        padding_l=None,
        num_heads=1,
        weight_dropout=0.0,
        weight_softmax=False,
        renorm_padding=False,
        bias=False,
        conv_bias=False,
        query_size=None,
        in_proj=False,
    ):
        super().__init__()
        self.input_size = input_size
        # When no separate query size is supplied, filters are predicted from x itself.
        self.query_size = input_size if query_size is None else query_size
        self.kernel_size = kernel_size
        self.padding_l = padding_l
        self.num_heads = num_heads
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
        self.weight_softmax = weight_softmax
        self.renorm_padding = renorm_padding

        if in_proj:
            # One projection produces both the convolved input and the filters.
            self.weight_linear = Linear(
                self.input_size, self.input_size + num_heads * kernel_size * 1
            )
        else:
            self.weight_linear = Linear(
                self.query_size, num_heads * kernel_size * 1, bias=bias
            )
        # Stored as a plain bool (not a property as in DynamicConv1dTBC) so
        # TorchScript can compile it.
        self.in_proj = (
            self.weight_linear.out_features
            == self.input_size + self.num_heads * self.kernel_size
        )
        self.has_conv_bias = conv_bias
        # The parameter is always allocated so the module structure is static;
        # has_conv_bias gates whether it is actually applied in forward().
        self.conv_bias = nn.Parameter(torch.Tensor(input_size).view(1, 1, -1))
        self.init_incremental_state()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the filter projection and zero the convolution bias."""
        self.weight_linear.reset_parameters()
        if self.has_conv_bias:
            nn.init.constant_(self.conv_bias, 0.0)

    def forward(
        self,
        x,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        query: Optional[Tensor] = None,
    ):
        """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C
        args:
            x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size)
            incremental_state: A dict to keep the state
            query: use the specified query to predict the conv filters
        """
        # An external query is incompatible with in_proj (filters come from x there).
        assert query is None or not self.in_proj

        if query is None:
            query = x

        # Only the unfolded path is supported in this scriptable variant.
        output = self._forward_unfolded(x, incremental_state, query)

        if self.has_conv_bias:
            output = output + self.conv_bias
        return output

    def _forward_unfolded(
        self,
        x,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        query,
    ):
        """The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right."""
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H  # channels per head
        assert R * H == C == self.input_size

        TxBxH = T * B * H

        if self.in_proj:
            proj = self.weight_linear(x)
            # First input_size channels are the (projected) input, the
            # remaining H*K channels are the per-position filters.
            x = proj.narrow(2, 0, self.input_size).contiguous()
            weight = proj.narrow(2, self.input_size, H * K).contiguous().view(TxBxH, -1)
        else:
            weight = self.weight_linear(query).view(TxBxH, -1)

        # renorm_padding is only implemented in _forward_expanded
        assert not self.renorm_padding or incremental_state is not None

        if incremental_state is not None:
            # Step-wise decoding: append the current input to the buffer of the
            # previous kernel_size - 1 steps.
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is not None:
                x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            else:
                x_unfold = x.unsqueeze(3).clone()
            if self.kernel_size > 1:
                self._set_input_buffer(
                    incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :]
                )
            x_unfold = x_unfold.view(TxBxH, R, -1)
        else:
            padding_l = self.padding_l
            # Shrink the kernel when it is longer than the sequence (only valid
            # for the fully left-padded case).
            if K > T and padding_l == K - 1:
                weight = weight.narrow(1, K - T, T)
                K, padding_l = T, T - 1
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, K, padding_l, 0.0)
            x_unfold = x_unfold.view(TxBxH, R, K)

        if self.weight_softmax and not self.renorm_padding:
            weight = F.softmax(weight, dim=1)
        weight = weight.narrow(1, 0, K)

        if incremental_state is not None:
            # Only keep the taps that have matching buffered inputs.
            weight = weight[:, -(x_unfold.size(2)) :]
            K = weight.size(1)

        if self.weight_softmax and self.renorm_padding:
            weight = F.softmax(weight, dim=1)

        weight = self.weight_dropout_module(weight, inplace=False)

        # Batched dot product of each window with its filter.
        output = torch.bmm(x_unfold, weight.unsqueeze(2))  # T x B x H x R x 1
        output = output.view(T, B, C)
        return output

    def reorder_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order: Tensor,
    ):
        """Reorder the buffered inputs along the batch dim (dim 1), e.g. for beam search."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ):
        # The buffer tensor is stored one level deep under the "input_buffer"
        # key (see _set_input_buffer).
        result = self.get_incremental_state(incremental_state, "input_buffer")
        if result is not None and "input_buffer" in result:
            return result["input_buffer"]
        else:
            return None

    def _set_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        new_buffer: Optional[Tensor],
    ):
        result = self.set_incremental_state(
            incremental_state, "input_buffer", {"input_buffer": new_buffer}
        )
        if result is not None:
            incremental_state = result
        return incremental_state

    def extra_repr(self):
        """Summarize the module configuration for ``repr``/``print``."""
        # NOTE(review): conv_bias is always a Parameter here, so the
        # "conv_bias={}" field below always prints True; has_conv_bias is the
        # flag that actually gates its use.
        s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format(  # noqa
            self.input_size,
            self.kernel_size,
            self.padding_l,
            self.num_heads,
            self.weight_softmax,
            self.conv_bias is not None,
            self.renorm_padding,
            self.in_proj,
        )

        if self.query_size != self.input_size:
            s += ", query_size={}".format(self.query_size)
        if self.weight_dropout_module.p > 0.0:
            s += ", weight_dropout={}".format(self.weight_dropout_module.p)
        return s