mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Removed the training config (it is being reworked) and added my LLVAE as a PyTorch model
This commit is contained in:
@@ -1,36 +0,0 @@
|
||||
{
|
||||
"job": "train",
|
||||
"config": {
|
||||
"name": "name_of_your_model",
|
||||
"base_model": "/path/to/base/model",
|
||||
"training_folder": "/path/to/output/folder",
|
||||
"is_v2": false,
|
||||
"device": "cuda",
|
||||
"gradient_accumulation_steps": 1,
|
||||
"mixed_precision": "fp16",
|
||||
"logging_dir": "/path/to/tensorboard/log/folder",
|
||||
|
||||
"process": [
|
||||
{
|
||||
"type": "fine_tune"
|
||||
}
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"name": "[name]",
|
||||
"description": "A short description of your model",
|
||||
"trigger_words": [
|
||||
"put",
|
||||
"trigger",
|
||||
"words",
|
||||
"here"
|
||||
],
|
||||
"version": "0.1",
|
||||
"creator": {
|
||||
"name": "Your Name",
|
||||
"email": "your@email.com",
|
||||
"website": "https://yourwebsite.com"
|
||||
},
|
||||
"any": "All meta data above is arbitrary, it can be whatever you want."
|
||||
}
|
||||
}
|
||||
118
toolkit/llvae.py
Normal file
118
toolkit/llvae.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
class LosslessLatentDecoder(nn.Module):
    """Fixed (non-learned) depth-to-space decoder.

    Inverse of ``LosslessLatentEncoder``: a transposed convolution with a
    hand-built 0/1 kernel that rearranges each group of
    ``latent_depth * latent_depth`` input channels back into a
    ``latent_depth x latent_depth`` spatial pixel cluster, losslessly.

    Args:
        in_channels: number of latent channels fed to ``forward``; assumed
            divisible by ``latent_depth ** 2`` — TODO confirm with callers.
        latent_depth: side length of the spatial cluster to reconstruct.
    """

    def __init__(self, in_channels, latent_depth):
        super(LosslessLatentDecoder, self).__init__()
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        # Each latent_depth x latent_depth output cluster consumes
        # latent_depth**2 input channels.
        self.out_channels = int(in_channels // (latent_depth * latent_depth))
        numpy_kernel = self.build_kernel(in_channels, latent_depth)
        # Fix: register as a (non-persistent) buffer so .to(device)/.cuda()
        # moves the kernel with the module. The original plain-attribute
        # tensor stayed on CPU and broke GPU inference. persistent=False
        # keeps state_dict unchanged for backward compatibility.
        self.register_buffer(
            'kernel', torch.from_numpy(numpy_kernel).float(), persistent=False
        )

    def build_kernel(self, in_channels, latent_depth):
        """Build the fixed 0/1 depth-to-space kernel as a numpy array.

        Ported from old TensorFlow code:
        tensorflow kernel layout is (height, width, out_channels, in_channels);
        pytorch conv_transpose2d layout is (in_channels, out_channels, height, width).
        """
        out_channels = self.out_channels

        # kernel_shape = [kernel_filter_size, kernel_filter_size, out_channels, in_channels] # tensorflow
        kernel_shape = [in_channels, out_channels, latent_depth, latent_depth]  # pytorch
        kernel = np.zeros(kernel_shape, np.float32)

        # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
        for c in range(0, out_channels):
            i = 0
            for x, y in itertools.product(range(latent_depth), repeat=2):
                # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
                kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0  # pytorch
                i += 1

        return kernel

    def forward(self, x):
        # Deconvolve (transposed convolution) with the fixed kernel; stride
        # equals latent_depth so each input position expands into one
        # latent_depth x latent_depth cluster.
        return nn.functional.conv_transpose2d(
            x, self.kernel, stride=self.latent_depth, padding=0, groups=1
        )
|
||||
|
||||
|
||||
class LosslessLatentEncoder(nn.Module):
    """Fixed (non-learned) space-to-depth encoder.

    A strided convolution with a hand-built 0/1 kernel that moves each
    ``latent_depth x latent_depth`` spatial pixel cluster into
    ``latent_depth ** 2`` separate channels, losslessly (each output
    channel picks exactly one pixel of the cluster).

    Args:
        in_channels: number of image channels fed to ``forward``.
        latent_depth: side length of the pixel cluster to fold into channels;
            input height/width are assumed divisible by it — TODO confirm.
    """

    def __init__(self, in_channels, latent_depth):
        super(LosslessLatentEncoder, self).__init__()
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        # Each input channel expands into latent_depth**2 latent channels.
        self.out_channels = int(in_channels * (latent_depth * latent_depth))
        numpy_kernel = self.build_kernel(in_channels, latent_depth)
        # Fix: register as a (non-persistent) buffer so .to(device)/.cuda()
        # moves the kernel with the module. The original plain-attribute
        # tensor stayed on CPU and broke GPU inference. persistent=False
        # keeps state_dict unchanged for backward compatibility.
        self.register_buffer(
            'kernel', torch.from_numpy(numpy_kernel).float(), persistent=False
        )

    def build_kernel(self, in_channels, latent_depth):
        """Build the fixed 0/1 space-to-depth kernel as a numpy array.

        Ported from old TensorFlow code:
        tensorflow kernel layout is (height, width, in_channels, out_channels);
        pytorch conv2d layout is (out_channels, in_channels, height, width).
        """
        out_channels = self.out_channels

        # kernel_shape = [latent_depth, latent_depth, in_channels, out_channels] # tensorflow
        kernel_shape = [out_channels, in_channels, latent_depth, latent_depth]  # pytorch
        kernel = np.zeros(kernel_shape, np.float32)

        # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
        for c in range(0, in_channels):
            i = 0
            for x, y in itertools.product(range(latent_depth), repeat=2):
                # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
                kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0  # pytorch
                i += 1
        return kernel

    def forward(self, x):
        # Convolve with the fixed kernel; stride equals latent_depth so each
        # latent_depth x latent_depth cluster collapses into one spatial
        # position with latent_depth**2 channels per input channel.
        return nn.functional.conv2d(
            x, self.kernel, stride=self.latent_depth, padding=0, groups=1
        )
|
||||
|
||||
|
||||
class LosslessLatentVAE(nn.Module):
    """Lossless "VAE"-shaped wrapper pairing the fixed space-to-depth encoder
    with its matching depth-to-space decoder, so a full forward pass is an
    identity round trip.

    Args:
        in_channels: number of image channels.
        latent_depth: cluster side length shared by encoder and decoder.
    """

    def __init__(self, in_channels, latent_depth):
        super(LosslessLatentVAE, self).__init__()
        self.latent_depth = latent_depth
        self.in_channels = in_channels
        self.encoder = LosslessLatentEncoder(in_channels, latent_depth)
        # Decoder consumes exactly the channels the encoder produces.
        encoder_out_channels = self.encoder.out_channels
        self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth)

    def forward(self, x):
        """Encode then decode; output should equal the input (lossless)."""
        # Bug fix: the original referenced self.latent_encoder /
        # self.latent_decoder, attributes that were never assigned
        # (AttributeError on every call). The modules are stored as
        # self.encoder / self.decoder in __init__.
        latent = self.encoder(x)
        out = self.decoder(latent)
        return out

    def encode(self, x):
        """Space-to-depth: fold pixel clusters into channels."""
        return self.encoder(x)

    def decode(self, x):
        """Depth-to-space: unfold channels back into pixel clusters."""
        return self.decoder(x)
|
||||
|
||||
|
||||
# test it
if __name__ == '__main__':
    import os

    from PIL import Image
    import torchvision.transforms as transforms

    # Round-trip a sample image through the lossless VAE and report the
    # reconstruction error (should be ~0) plus the intermediate shapes.
    home_dir = os.path.expanduser('~')

    src_path = os.path.join(home_dir, "Pictures/sample_2_512.png")
    dst_path = os.path.join(home_dir, "Pictures/sample_2_512_llvae.png")

    sample = transforms.ToTensor()(Image.open(src_path)).unsqueeze(0)
    print("input_shape: ", list(sample.shape))

    model = LosslessLatentVAE(in_channels=3, latent_depth=8)
    encoded = model.encode(sample)
    print("latent_shape: ", list(encoded.shape))
    decoded = model.decode(encoded)
    print("out_shape: ", list(decoded.shape))

    roundtrip = nn.MSELoss()(sample, decoded)
    print("roundtrip_loss: ", roundtrip.item())

    transforms.ToPILImage()(decoded.squeeze(0)).save(dst_path)
|
||||
Reference in New Issue
Block a user