mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Removed train config, updating it, and added my llvae as pytorch model
This commit is contained in:
118
toolkit/llvae.py
Normal file
118
toolkit/llvae.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
class LosslessLatentDecoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth):
|
||||
super(LosslessLatentDecoder, self).__init__()
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).float()
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
# my old code from tensorflow.
|
||||
# tensorflow kernel is (height, width, out_channels, in_channels)
|
||||
# pytorch kernel is (in_channels, out_channels, height, width)
|
||||
out_channels = self.out_channels
|
||||
|
||||
# kernel_shape = [kernel_filter_size, kernel_filter_size, out_channels, in_channels] # tensorflow
|
||||
kernel_shape = [in_channels, out_channels, latent_depth, latent_depth] # pytorch
|
||||
kernel = np.zeros(kernel_shape, np.float32)
|
||||
|
||||
# Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
|
||||
for c in range(0, out_channels):
|
||||
i = 0
|
||||
for x, y in itertools.product(range(latent_depth), repeat=2):
|
||||
# kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
|
||||
kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch
|
||||
i += 1
|
||||
|
||||
return kernel
|
||||
|
||||
def forward(self, x):
|
||||
# Deconvolve input tensor with the kernel
|
||||
return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
|
||||
|
||||
|
||||
class LosslessLatentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth):
|
||||
super(LosslessLatentEncoder, self).__init__()
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).float()
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
# my old code from tensorflow.
|
||||
# tensorflow kernel is (height, width, in_channels, out_channels)
|
||||
# pytorch kernel is (out_channels, in_channels, height, width)
|
||||
out_channels = self.out_channels
|
||||
|
||||
# kernel_shape = [latent_depth, latent_depth, in_channels, out_channels] # tensorflow
|
||||
kernel_shape = [out_channels, in_channels, latent_depth, latent_depth] # pytorch
|
||||
kernel = np.zeros(kernel_shape, np.float32)
|
||||
|
||||
# Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
|
||||
for c in range(0, in_channels):
|
||||
i = 0
|
||||
for x, y in itertools.product(range(latent_depth), repeat=2):
|
||||
# kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
|
||||
kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch
|
||||
i += 1
|
||||
return kernel
|
||||
|
||||
def forward(self, x):
|
||||
# Convolve input tensor with the kernel
|
||||
return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
|
||||
|
||||
|
||||
class LosslessLatentVAE(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth):
|
||||
super(LosslessLatentVAE, self).__init__()
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth)
|
||||
encoder_out_channels = self.encoder.out_channels
|
||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth)
|
||||
|
||||
def forward(self, x):
|
||||
latent = self.latent_encoder(x)
|
||||
out = self.latent_decoder(latent)
|
||||
return out
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
|
||||
|
||||
# test it
|
||||
if __name__ == '__main__':
|
||||
import os
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
user_path = os.path.expanduser('~')
|
||||
|
||||
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")
|
||||
output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png")
|
||||
img = Image.open(input_path)
|
||||
img_tensor = transforms.ToTensor()(img)
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
print("input_shape: ", list(img_tensor.shape))
|
||||
vae = LosslessLatentVAE(in_channels=3, latent_depth=8)
|
||||
latent = vae.encode(img_tensor)
|
||||
print("latent_shape: ", list(latent.shape))
|
||||
out_tensor = vae.decode(latent)
|
||||
print("out_shape: ", list(out_tensor.shape))
|
||||
|
||||
mse_loss = nn.MSELoss()
|
||||
mse = mse_loss(img_tensor, out_tensor)
|
||||
print("roundtrip_loss: ", mse.item())
|
||||
out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
|
||||
out_img.save(output_path)
|
||||
Reference in New Issue
Block a user