mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-02-22 06:04:26 +00:00
52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import uuid
|
|
from typing import Dict, Optional
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
class FairseqIncrementalState:
    """Mixin giving each instance its own namespace inside a shared
    incremental-state dictionary.

    On construction every instance receives a random UUID string. Keys passed
    to :meth:`get_incremental_state` / :meth:`set_incremental_state` are
    prefixed with ``"<uuid>."``, so entries written by different instances
    into the same ``incremental_state`` dict do not collide.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: hand all arguments to the next class in the MRO
        # first, then tag this instance with its private state id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """(Re)assign this instance a fresh random id (a UUID4 string)."""
        fresh_id = uuid.uuid4()
        self._incremental_state_id = str(fresh_id)

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` namespaced by this instance's id: ``"<id>.<key>"``."""
        return "{0}.{1}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up this instance's entry for ``key`` (e.g. for an nn.Module).

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: the un-namespaced key; it is prefixed with this instance's id.

        Returns:
            The stored object itself (not a copy), or ``None`` when
            ``incremental_state`` is ``None`` or holds no entry for the key.
        """
        namespaced = self._get_full_incremental_state_key(key)
        if incremental_state is not None and namespaced in incremental_state:
            return incremental_state[namespaced]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store ``value`` under this instance's namespaced ``key`` (e.g. for an nn.Module).

        The dict is mutated in place; nothing happens when it is ``None``.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: the un-namespaced key; it is prefixed with this instance's id.
            value: the object to store.

        Returns:
            The same ``incremental_state`` object that was passed in
            (``None`` if ``None`` was passed).
        """
        if incremental_state is None:
            return incremental_state
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
|
|
|
|
|
|
def with_incremental_state(cls):
    """Class decorator that mixes ``FairseqIncrementalState`` into ``cls``.

    The mixin is moved to the front of ``cls.__bases__``. Any existing direct
    occurrence is removed first, so applying the decorator twice is harmless.
    Putting it first makes its methods reachable and lets its cooperative
    ``__init__`` run through ``super().__init__`` chains.

    If ``cls`` already inherits the mixin indirectly (for example, it
    subclasses another decorated class), ``cls`` is returned unchanged.
    Prepending the mixin again would put it both before and after that parent
    in the C3 linearization, and the ``__bases__`` assignment would raise
    ``TypeError``.

    Args:
        cls: the class to modify in place.

    Returns:
        ``cls`` itself (modified in place; no new class is created).
    """
    mixin = FairseqIncrementalState
    if mixin not in cls.__bases__ and issubclass(cls, mixin):
        # Already inherited through a parent: nothing to add, and re-prepending
        # would yield an inconsistent MRO.
        return cls
    # NOTE(review): CPython generally rejects __bases__ assignment when the
    # class's only base is `object` (deallocator/layout mismatch), so this
    # presumably expects `cls` to have a real base such as nn.Module — confirm.
    cls.__bases__ = (mixin,) + tuple(b for b in cls.__bases__ if b is not mixin)
    return cls
|