mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bug fixes. Added some functionality to help with private extensions
This commit is contained in:
@@ -5,14 +5,18 @@ import itertools
|
||||
|
||||
|
||||
class LosslessLatentDecoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentDecoder, self).__init__()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
if trainable:
|
||||
self.kernel = nn.Parameter(numpy_kernel)
|
||||
else:
|
||||
self.kernel = numpy_kernel
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
# my old code from tensorflow.
|
||||
@@ -44,14 +48,18 @@ class LosslessLatentDecoder(nn.Module):
|
||||
|
||||
|
||||
class LosslessLatentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentEncoder, self).__init__()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
if trainable:
|
||||
self.kernel = nn.Parameter(numpy_kernel)
|
||||
else:
|
||||
self.kernel = numpy_kernel
|
||||
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
@@ -82,13 +90,13 @@ class LosslessLatentEncoder(nn.Module):
|
||||
|
||||
|
||||
class LosslessLatentVAE(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentVAE, self).__init__()
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
|
||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||
encoder_out_channels = self.encoder.out_channels
|
||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
|
||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||
|
||||
def forward(self, x):
|
||||
latent = self.latent_encoder(x)
|
||||
|
||||
Reference in New Issue
Block a user