mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-10 05:50:10 +00:00
59 lines · 1.6 KiB · Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
|
|
|
|
from fairseq.distributed import utils
|
|
|
|
|
|
# fairscale is an optional dependency: probe for its OSS (Optimizer State
# Sharding) wrapper at import time and remember the result, so shard_()
# can raise a helpful install hint instead of a bare NameError later.
try:
    from fairscale.optim import OSS

    _has_fairscale = True
except ImportError:
    _has_fairscale = False
|
|
|
|
|
|
def shard_(optimizer, group):
    """Shard *optimizer*'s state across the workers in *group* using
    fairscale's OSS wrapper.

    Mutates ``optimizer.optimizer`` in place, replacing the inner torch
    optimizer with a ``FairseqOSS`` instance built from the same param
    groups and the fairseq optimizer's config.

    Raises:
        ImportError: if the fairscale package is not installed.
    """
    if not _has_fairscale:
        raise ImportError(
            "\n\nPlease install the fairscale package:" "\n\n  pip install fairscale"
        )

    class FairseqOSS(OSS):
        # fairscale OSS subclass that plugs into fairseq's optimizer
        # machinery (capability flags + distributed state broadcast).

        @property
        def disable_mem_eff_fp16_loading_hack(self):
            # Signals fairseq to skip its memory-efficient fp16 loading
            # workaround for this wrapper.
            return True

        def __getattr__(self, name):
            # Forward fairseq capability queries ("supports*") to the
            # wrapped torch optimizer; anything else is a genuine miss.
            if not name.startswith("supports") or not hasattr(self.optim, name):
                raise AttributeError(
                    "'FairseqOSS' object has no attribute {0!r}".format(name)
                )
            return getattr(self.optim, name)

        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcasts the entire state_dict to all other ranks
            each rank is responsible to load their own partition of data
            """
            return utils.broadcast_object(
                state_dict,
                src_rank=0,
                group=self.group,
            )

    base_optimizer = optimizer.optimizer
    base_optimizer_cls = type(base_optimizer)

    # Rebuild the inner optimizer as a sharded OSS instance with identical
    # param groups and configuration.
    optimizer.optimizer = FairseqOSS(
        base_optimizer.param_groups,
        base_optimizer_cls,
        group=group,
        **optimizer.optimizer_config
    )