# mirror of https://github.com/ostris/ai-toolkit.git

from typing import Optional, Union

import torch
import torch.nn as nn

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


class PixelMixer(nn.Module):
    def __init__(self, in_channels, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        return torch.nn.PixelUnshuffle(self.downscale_factor)(x)

    def decode(self, x):
        return torch.nn.PixelShuffle(self.downscale_factor)(x)
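

# Shape sketch for PixelMixer (PixelShuffle exactly inverts PixelUnshuffle,
# so the encode/decode round trip is lossless; H and W must be divisible by
# downscale_factor):
#   encode: (B, C, H, W) -> (B, C * f**2, H // f, W // f)
#   decode: (B, C * f**2, H // f, W // f) -> (B, C, H, W)
# e.g. f=8, C=3: a 1x3x512x512 image becomes a 1x192x64x64 "latent".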


# for reference
# none of this matters with the lossless pixel-shuffle VAE ("llvae"), but we
# need to match the AutoencoderKL config interface (latent_channels might
# matter)
class Config:
    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192  # usually 4
    norm_num_groups = 32
    sample_size = 512
    # scaling_factor = 1
    # shift_factor = 0
    scaling_factor = 1.8
    shift_factor = -0.123
    # VAE latent statistics (measured):
    # - Mean: -0.12306906282901764
    # - Std: 0.556016206741333
    # Normalization parameters:
    # - Shift factor: -0.12306906282901764
    # - Scaling factor: 1.7985087266803625

    # allow dict-style access, e.g. config['latent_channels']
    def __getitem__(self, x):
        return getattr(self, x)
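

# How callers typically apply shift/scaling in the diffusers convention (a
# sketch of downstream usage, not code from this file):
#   z = (latent - shift_factor) * scaling_factor   # after encode
#   latent = z / scaling_factor + shift_factor     # before decode
# shift_factor is the measured latent mean and scaling_factor is roughly
# 1 / std (1 / 0.556 ~= 1.7985), normalizing latents toward zero mean and
# unit variance.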


class AutoencoderPixelMixer(nn.Module):
    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()

        if downscale_factor == 8:
            # downstream code infers the downscale factor from
            # len(block_out_channels) (factor = 2 ** (len - 1)), so simulate it
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192  # 3 * 8**2
        elif downscale_factor == 16:
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768  # 3 * 16**2
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic torch.nn.Module.to while tracking dtype/device
    def to(self, *args, **kwargs):
        # record dtype and device if they were passed as keywords
        # (positional args are forwarded but not recorded)
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.mixer.encode(x)

        # a learned AutoencoderKL would do:
        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (h,)

        # deterministic stand-in for DiagonalGaussianDistribution:
        # sample() and mode() both return the unshuffled latent unchanged
        class FakeDist:
            def __init__(self, x):
                self._sample = x

            def sample(self, generator=None):
                return self._sample

            def mode(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))
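
    # encode() usage sketch (hypothetical call; shapes assume in_channels=3,
    # downscale_factor=8 and x of shape (B, 3, H, W)):
    #   vae.encode(x).latent_dist.sample()   # -> (B, 192, H // 8, W // 8)
    #   vae.encode(x, return_dict=False)[0]  # same tensor, tuple-wrapped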

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    # no-op stubs so this class can stand in for AutoencoderKL
    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
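

# Drop-in usage sketch (hypothetical; mirrors how an AutoencoderKL from
# diffusers is commonly driven, using only names defined in this file):
#   vae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
#   latents = vae.encode(images).latent_dist.sample()
#   latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
#   ... run the diffusion model on latents ...
#   latents = latents / vae.config.scaling_factor + vae.config.shift_factor
#   images_out = vae.decode(latents).sample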


# test it
if __name__ == '__main__':
    import os

    from PIL import Image
    import torchvision.transforms as transforms

    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    # write the round trip to a separate file so the input is not overwritten
    output_path = os.path.join(user_path, "Pictures/test/test_out.jpg")
    img = Image.open(input_path).convert('RGB')  # force 3 channels
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))
    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))
    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    # the shuffle/unshuffle round trip is exact, so this should print ~0.0
    print("roundtrip_loss: ", mse.item())
    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
    out_img.save(output_path)
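
    # also exercise the AutoencoderKL-style wrapper end to end (a sketch of
    # an extra sanity check; it should report the same ~0.0 loss)
    ae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
    ae_out = ae(img_tensor).sample
    print("wrapper_roundtrip_loss: ", mse_loss(img_tensor, ae_out).item())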