# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    """Minimal DistributedDataParallel analogue for TPUs.

    Unlike ``torch.nn.parallel.DistributedDataParallel``, gradients are not
    reduced automatically during the backward pass; the trainer must call
    :meth:`all_reduce_grads` explicitly after ``backward()``.
    """

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        # Plain pass-through: no hooks are registered on the wrapped module.
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                # Parameters that received no gradient still take part in the
                # collective, so every replica reduces the same set of tensors.
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            gradients.append(p.grad)

        # Imported lazily so this module can be loaded on hosts without
        # torch_xla installed.
        import torch_xla.core.xla_model as xm

        # Sum gradients across replicas in place, scaling by 1/world_size so
        # the result is the average. The second element of the process-group
        # tuple holds the replica groups expected by torch_xla.
        xm.all_reduce(
            "sum",
            gradients,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
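

# --- Usage sketch (illustrative, not part of the original fairseq file) ---
# A minimal example of how this wrapper might be driven in a TPU training
# step. Assumptions: torch_xla is installed and an XLA device is available;
# utils.get_global_group() returns the ("tpu", groups)-style tuple whose
# second element feeds `groups=` above; the linear model and random batch
# are placeholders, not fairseq APIs.
if __name__ == "__main__":
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = nn.Linear(8, 2).to(device)
    wrapped = TPUDistributedDataParallel(model, utils.get_global_group())

    optimizer = torch.optim.SGD(wrapped.parameters(), lr=0.1)
    loss = wrapped(torch.randn(4, 8, device=device)).sum()
    loss.backward()
    wrapped.all_reduce_grads()  # average gradients across all replicas
    optimizer.step()
    xm.mark_step()  # materialize the queued XLA computation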