mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-01-26 17:20:04 +00:00
76 lines
2.1 KiB
Python
76 lines
2.1 KiB
Python
from typing import Callable
|
|
|
|
import torch
|
|
from torch import zero_
|
|
from torch.nn import Module
|
|
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
|
|
|
|
|
|
def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    """Build a callable that re-initializes ``module.weight`` in place.

    Args:
        method: Name of the initialization scheme:
            ``'none'``      -- return the module untouched;
            ``'he'``        -- ``kaiming_normal_`` with its default arguments;
            ``'xavier'``    -- ``xavier_normal_`` with its default gain;
            ``'dcgan'``     -- ``normal_`` with mean 0.0, std 0.02;
            ``'dcgan_001'`` -- ``normal_`` with mean 0.0, std 0.01;
            ``'zero'``      -- fill the weight with zeros.

    Returns:
        A function taking a module, initializing its ``weight`` attribute
        according to ``method`` and returning the same module object.

    Note:
        ``method`` is only checked when the returned function is called, so an
        unknown name is reported at initialization time, not here.

    The returned function raises:
        ValueError: if ``method`` is not one of the names listed above.
    """
    def init(module: Module) -> Module:
        if method == 'none':
            return module
        elif method == 'he':
            kaiming_normal_(module.weight)
            return module
        elif method == 'xavier':
            xavier_normal_(module.weight)
            return module
        elif method == 'dcgan':
            normal_(module.weight, 0.0, 0.02)
            return module
        elif method == 'dcgan_001':
            normal_(module.weight, 0.0, 0.01)
            return module
        elif method == "zero":
            # An in-place zero_ on a parameter that requires grad is only
            # allowed with autograd recording disabled.
            with torch.no_grad():
                zero_(module.weight)
            return module
        else:
            # Must raise an exception instance: raising a plain str is itself
            # a TypeError and would hide this message.
            raise ValueError("Invalid initialization method %s" % method)

    return init
|
|
|
|
|
|
class HeInitialization:
    """Callable that applies Kaiming (He) normal initialization to ``module.weight``.

    The three constructor arguments are stored unchanged and forwarded as
    keyword arguments to ``kaiming_normal_`` on every call.
    """

    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        # Kept in the same order as the constructor signature.
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def __call__(self, module: Module) -> Module:
        """Redraw ``module.weight`` in place (autograd off) and return ``module``."""
        options = dict(a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        weight = module.weight
        with torch.no_grad():
            kaiming_normal_(weight, **options)
        return module
|
|
|
|
|
|
class NormalInitialization:
    """Callable that redraws ``module.weight`` from a normal distribution.

    Args:
        mean: Mean of the distribution passed to ``normal_``.
        std: Standard deviation passed to ``normal_``.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean = mean
        self.std = std

    def __call__(self, module: Module) -> Module:
        """Sample new weights in place with autograd off, then return ``module``."""
        target = module.weight
        with torch.no_grad():
            normal_(target, mean=self.mean, std=self.std)
        return module
|
|
|
|
|
|
class XavierInitialization:
    """Callable that applies Xavier (Glorot) normal initialization to ``module.weight``.

    Args:
        gain: Scaling factor forwarded to ``xavier_normal_``.
    """

    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        """Reinitialize ``module.weight`` in place and hand back the same ``module``."""
        with torch.no_grad():
            xavier_normal_(module.weight, gain=self.gain)
        return module
|
|
|
|
|
|
class ZeroInitialization:
    """Callable that fills ``module.weight`` with zeros in place."""

    def __call__(self, module: Module) -> Module:
        """Zero ``module.weight`` and return the same ``module``.

        ``torch.no_grad`` must be instantiated (``torch.no_grad()``): using the
        bare class in a ``with`` statement is not a valid context manager and
        fails at runtime. Grad mode has to be off because an in-place ``zero_``
        on a leaf parameter that requires grad is otherwise rejected.
        """
        with torch.no_grad():
            zero_(module.weight)
        return module
|
|
|
|
class NoInitialization:
    """Initializer that deliberately leaves the module's parameters unchanged."""

    def __call__(self, module: Module) -> Module:
        # Intentional no-op. The argument is returned so this class keeps the
        # same call/return contract as the other initializer classes here.
        return module