# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.

This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
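
# A rough sketch of one way to drive this module when accumulating gradients
# over several batches, as mentioned above; `model`, `samples`, `criterion`,
# `optimizer`, and `process_group` below are placeholders for the caller's own
# objects, not names defined in this file:
#
#     ddp_model = LegacyDistributedDataParallel(model, process_group)
#     for sample in samples:              # e.g. update_freq batches per update
#         with ddp_model.no_sync():       # defer gradient synchronization
#             loss = criterion(ddp_model(**sample))
#             loss.backward()             # gradients accumulate locally in .grad
#     ddp_model.all_reduce_grads()        # one all-reduce per parameter update
#     optimizer.step()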

from collections import OrderedDict
from contextlib import contextmanager

import torch
from torch import nn

from fairseq.distributed import utils


class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2**28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())
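        # NOTE: grouping parameters by device means each bucket built in
        # all_reduce_grads() only ever mixes gradients that live on the same
        # device.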

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
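        # While this context manager is active, all_reduce_grads() returns
        # without communicating, so gradients simply accumulate in param.grad
        # on each worker.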
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads
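
    # Note that forward() is a plain pass-through: no gradient hooks are
    # registered on the module, so synchronization only happens when
    # all_reduce_grads() is called explicitly.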
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """
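
        # Synchronization happens in two steps: all_reduce_params() flattens a
        # bucket of gradients into a contiguous buffer, all-reduces it, and
        # copies the result back, while reduction_fn() walks the parameters
        # and decides how to bucket them.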
        def all_reduce_params(params):
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)
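
            # Dividing by world_size before the all-reduce (which sums across
            # workers) yields the average gradient; the division is skipped
            # when the buffer is known to contain only zeros.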
            if nonzero_buffer:
                buffer.div_(self.world_size)

            utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                return
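
            # Lazily allocate the flattening buffer on first use, with the
            # same dtype and device as the model parameters.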
            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared parameters
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
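
                    # Parameters larger than the buffer are all-reduced on
                    # their own; smaller ones are packed into the buffer and
                    # flushed whenever the next gradient would overflow it.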
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()