mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-02 20:31:18 +00:00
Add monkey patched fairseq package to run on python 3.11 (what is needed for our use of RVC at least)
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ModuleProxyWrapper(nn.Module):
|
||||
"""
|
||||
Wrap a DistributedDataParallel module and forward requests for missing
|
||||
attributes to the module wrapped by DDP (the twice-wrapped module).
|
||||
Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
|
||||
|
||||
Usage::
|
||||
|
||||
module.xyz = "hello world"
|
||||
wrapped_module = DistributedDataParallel(module, **ddp_args)
|
||||
wrapped_module = ModuleProxyWrapper(wrapped_module)
|
||||
assert wrapped_module.xyz == "hello world"
|
||||
assert wrapped_module.state_dict().keys() == module.state_dict().keys()
|
||||
|
||||
Args:
|
||||
module (nn.Module): module to wrap
|
||||
"""
|
||||
|
||||
def __init__(self, module: nn.Module):
|
||||
super().__init__()
|
||||
assert hasattr(
|
||||
module, "module"
|
||||
), "ModuleProxyWrapper expects input to wrap another module"
|
||||
self.module = module
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Forward missing attributes to twice-wrapped module."""
|
||||
try:
|
||||
# defer to nn.Module's logic
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
try:
|
||||
# forward to the once-wrapped module
|
||||
return getattr(self.module, name)
|
||||
except AttributeError:
|
||||
# forward to the twice-wrapped module
|
||||
return getattr(self.module.module, name)
|
||||
|
||||
def state_dict(self, *args, **kwargs):
|
||||
"""Forward to the twice-wrapped module."""
|
||||
return self.module.module.state_dict(*args, **kwargs)
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
"""Forward to the twice-wrapped module."""
|
||||
return self.module.module.load_state_dict(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
||||
Reference in New Issue
Block a user