# Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git
# Synced 2026-03-11 22:40:03 +00:00
# 41 lines, 1.7 KiB, Python
from typing import Optional, Callable, Union
|
|
|
|
from torch.nn import Module
|
|
|
|
from tha3.module.module_factory import ModuleFactory
|
|
from tha3.nn.init_function import create_init_function
|
|
from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory
|
|
from tha3.nn.normalization import NormalizationLayerFactory
|
|
from tha3.nn.spectral_norm import apply_spectral_norm
|
|
|
|
|
|
def wrap_conv_or_linear_module(module: Module,
                               initialization_method: Union[str, Callable[[Module], Module]],
                               use_spectral_norm: bool):
    """Initialize ``module`` and hand it to ``apply_spectral_norm``.

    Args:
        module: The module to initialize and wrap (per the function name,
            presumably a convolution or linear layer -- not enforced here).
        initialization_method: Either a string, which is converted to an
            initializer with ``create_init_function``, or a callable that
            takes the module and returns the initialized module.
        use_spectral_norm: Forwarded unchanged as the second argument of
            ``apply_spectral_norm``.

    Returns:
        The result of ``apply_spectral_norm`` applied to the initialized
        module.
    """
    if isinstance(initialization_method, str):
        # A string selects an initializer by name; resolve it to a callable.
        init = create_init_function(initialization_method)
    else:
        # Already a callable: use it as the initializer directly.
        init = initialization_method
    return apply_spectral_norm(init(module), use_spectral_norm)
|
|
|
|
|
|
class BlockArgs:
    """Bundle of construction options shared by network building blocks.

    Stores the initialization method, the spectral-norm flag, the
    normalization layer factory and the nonlinearity factory, and offers
    helpers that apply the initialization / spectral-norm settings to a
    module.
    """

    def __init__(self,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 use_spectral_norm: bool = False,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 nonlinearity_factory: Optional[ModuleFactory] = None):
        """Record the block options.

        Args:
            initialization_method: Initializer given either as a string
                (resolved later via ``create_init_function``) or as a
                callable taking and returning a module. Defaults to ``'he'``.
            use_spectral_norm: Flag forwarded to
                ``wrap_conv_or_linear_module`` by :meth:`wrap_module`.
            normalization_layer_factory: Stored as given; may be ``None``.
            nonlinearity_factory: Passed through
                ``resolve_nonlinearity_factory`` before being stored, so the
                stored attribute is that function's result rather than the
                raw argument.
        """
        self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
        self.normalization_layer_factory = normalization_layer_factory
        self.use_spectral_norm = use_spectral_norm
        self.initialization_method = initialization_method

    def wrap_module(self, module: Module) -> Module:
        """Apply this object's init function and spectral-norm flag to ``module``.

        Delegates to ``wrap_conv_or_linear_module`` with the callable
        returned by :meth:`get_init_func` and ``self.use_spectral_norm``.
        """
        return wrap_conv_or_linear_module(module, self.get_init_func(), self.use_spectral_norm)

    def get_init_func(self) -> Callable[[Module], Module]:
        """Return the initialization method as a callable.

        A string ``initialization_method`` is converted with
        ``create_init_function``; anything else is returned unchanged.
        """
        if isinstance(self.initialization_method, str):
            return create_init_function(self.initialization_method)
        else:
            return self.initialization_method
|