Added my good ole pattern loss. God I love that thing, conv transpose pattern instantly wiped from vae

This commit is contained in:
Jaret Burkett
2023-07-20 15:44:16 -06:00
parent 982e0be7a9
commit 0761656a90
4 changed files with 56 additions and 12 deletions

View File

@@ -5,13 +5,14 @@ import itertools
class LosslessLatentDecoder(nn.Module):
def __init__(self, in_channels, latent_depth):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
super(LosslessLatentDecoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels // (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).float()
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
def build_kernel(self, in_channels, latent_depth):
# my old code from tensorflow.
@@ -39,13 +40,15 @@ class LosslessLatentDecoder(nn.Module):
class LosslessLatentEncoder(nn.Module):
def __init__(self, in_channels, latent_depth):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
super(LosslessLatentEncoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels * (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).float()
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
def build_kernel(self, in_channels, latent_depth):
# my old code from tensorflow.
@@ -72,13 +75,13 @@ class LosslessLatentEncoder(nn.Module):
class LosslessLatentVAE(nn.Module):
def __init__(self, in_channels, latent_depth):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
super(LosslessLatentVAE, self).__init__()
self.latent_depth = latent_depth
self.in_channels = in_channels
self.encoder = LosslessLatentEncoder(in_channels, latent_depth)
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
encoder_out_channels = self.encoder.out_channels
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth)
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
def forward(self, x):
latent = self.latent_encoder(x)