# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch
from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    @property
    def unwrapped_module(self) -> torch.nn.Module:
        if self.flatten_parameters:
            return self.module.module
        else:
            return self.module

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)
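
# Illustrative sketch (not part of the original fairseq module) of how the two
# checkpointing modes of the wrapper above are typically consumed. Assumes
# ``model`` is an already constructed FullyShardedDataParallel instance; the
# helper name and the per-rank path suffix are hypothetical.
def _example_save_checkpoint(model: FullyShardedDataParallel, path: str) -> None:
    state = model.state_dict()
    if model.use_sharded_state:
        # With sharded state, every data parallel rank holds only its own
        # shard, so each rank writes its own file.
        torch.save(state, f"{path}.rank{model.rank}")
    elif model.rank == 0:
        # Otherwise the full weights are gathered onto rank 0; the other
        # ranks receive an empty dict above and write nothing.
        torch.save(state, path)
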
class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size
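
# Quick illustrative check (not part of the original module): the dummy group
# stands in for a real torch.distributed process group in single-process runs,
# answering the rank()/size() queries without requiring torch.distributed to
# be initialized.
def _example_dummy_group() -> None:
    group = DummyProcessGroup(rank=0, size=1)
    assert group.rank() == 0
    assert group.size() == 1
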
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": not cfg.not_fsdp_flatten_parameters,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
        "state_dict_device": torch.device("cpu"),  # reduce GPU mem usage
    }
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        use_sharded_state=cfg.use_sharded_state,
        **fsdp_config,
    ):
        yield
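
# Illustrative usage sketch (not part of the original fairseq module):
# fsdp_enable_wrap() is meant to wrap model construction, so that each
# fsdp_wrap() call (defined just below) picks up the fsdp_config assembled
# above. Assumes fairscale is installed and ``cfg`` is a
# DistributedTrainingConfig for a single-process run
# (distributed_world_size == 1), so the DummyProcessGroup path is taken.
def _example_build_wrapped_model(cfg: DistributedTrainingConfig) -> torch.nn.Module:
    with fsdp_enable_wrap(cfg):
        # Wrap inner layers first, then the root module, so each wrapped
        # unit shards and reshards its own parameters.
        inner = fsdp_wrap(torch.nn.Linear(16, 16))
        model = fsdp_wrap(torch.nn.Sequential(inner, torch.nn.ReLU()))
    return model
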
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module
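
# Illustrative sketch (not part of the original module) of the
# ``min_num_params`` threshold: submodules below the threshold are returned
# unchanged, which avoids FSDP overhead for small layers. For the wrapping to
# take effect this must run inside fsdp_enable_wrap(); outside that context
# fairscale's wrap() also returns the module unchanged.
def _example_min_num_params(module: torch.nn.Module) -> torch.nn.Module:
    # With a 1M-parameter threshold, e.g. a Linear(4, 4) (20 parameters) is
    # left as-is, while large embedding or projection layers would be wrapped.
    return fsdp_wrap(module, min_num_params=int(1e6))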