Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git (synced 2026-04-29 10:51:19 +00:00)
Add monkey-patched fairseq package to run on Python 3.11 (what is needed for our use of RVC, at least)
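Note: the diff below only adds the patched file; it does not show how the bundled copy is made to take precedence over a pip-installed fairseq. A minimal sketch of that general shadowing technique, assuming a hypothetical modules/fairseq_patched directory inside the extras checkout (not a path taken from this repository), would be:

# Illustrative only: put the directory holding the patched fairseq package
# first on sys.path before anything imports fairseq, so the bundled copy
# shadows the installed one. "modules/fairseq_patched" is a placeholder name.
import os
import sys

patched_dir = os.path.join(os.path.dirname(__file__), "modules", "fairseq_patched")
sys.path.insert(0, patched_dir)

import fairseq  # resolves to the patched copy if it exists there
print(fairseq.__file__)  # quick check of which copy was actually picked up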
@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            gradients.append(p.grad)

        import torch_xla.core.xla_model as xm

        xm.all_reduce(
            "sum",
            gradients,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
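For context, here is a hedged sketch of how a wrapper like the one added above is typically driven in a training step on TPU. It assumes torch_xla is installed and that process_group is a ("tpu", replica_groups) tuple, a shape inferred from the self.process_group[1] indexing in all_reduce_grads(); in a real run fairseq's own distributed initialization would normally construct that group. None of this is part of the commit itself.

import torch
from torch import nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()               # current TPU core
model = nn.Linear(16, 4).to(device)

# Assumed group shape: ("tpu", replica_groups); here a single group of all replicas.
process_group = ("tpu", [list(range(xm.xrt_world_size()))])
ddp = TPUDistributedDataParallel(model, process_group)

x = torch.randn(8, 16, device=device)
loss = ddp(x).sum()                    # forward just delegates to the wrapped module
loss.backward()

# Unlike torch.nn.parallel.DistributedDataParallel, gradients are not synced
# during backward; the caller reduces them explicitly, then flushes the XLA graph.
ddp.all_reduce_grads()
xm.mark_step()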