# FakeVAE: an identity, pass-through stand-in for a diffusers AutoencoderKL.
# encode/decode return their input unchanged while exposing the usual VAE
# interface (config, latent_dist, tiling/slicing toggles) to pipeline code.
from typing import Optional, Union

import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


class Config:
    # minimal stand-in for a diffusers VAE config object; only the attributes
    # that pipeline code actually reads are provided
    in_channels = 3
    out_channels = 3
    down_block_types = ("1",)
    up_block_types = ("1",)
    block_out_channels = (1,)
    latent_channels = 3  # usually 4
    norm_num_groups = 1
    sample_size = 512
    scaling_factor = 1.0
    # scaling_factor = 1.8
    shift_factor = 0

    # allow dict-style access, e.g. config["scaling_factor"]
    def __getitem__(self, x):
        return getattr(self, x)


class FakeVAE(nn.Module):
    # identity VAE: encode/decode pass tensors through unchanged while
    # mimicking the AutoencoderKL interface
    def __init__(self, scaling_factor=1.0):
        super().__init__()
        self._dtype = torch.float32
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()
        self.config.scaling_factor = scaling_factor

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic torch.nn.Module.to so callers can move/cast the fake VAE
    def to(self, *args, **kwargs):
        # track dtype/device whether they are passed positionally or as keywords
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
            elif isinstance(arg, (str, torch.device)):
                self._device = torch.device(arg)
        if "dtype" in kwargs:
            self._dtype = kwargs["dtype"]
        if "device" in kwargs:
            self._device = torch.device(kwargs["device"])
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    # @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> AutoencoderKLOutput:
        # identity encode: the "latents" are just the input pixels
        h = x

        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (h,)

        class FakeDist:
            # mimics DiagonalGaussianDistribution just enough for callers
            # that do encode(x).latent_dist.sample()
            def __init__(self, x):
                self._sample = x

            def sample(self, generator=None):
                # generator is accepted for API parity but ignored
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))

    def _decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        # identity decode: latents pass straight through as pixels
        dec = z

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    # no-op stubs so pipeline code written for AutoencoderKL keeps working
    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        # identity forward; sample_posterior and generator exist for API
        # parity with AutoencoderKL.forward and are ignored
        dec = sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
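

# A minimal usage sketch (illustration only, not part of the original file):
# FakeVAE can stand in wherever an AutoencoderKL is expected, e.g. training
# code that does encode -> scale -> denoise -> unscale -> decode.
if __name__ == "__main__":
    vae = FakeVAE(scaling_factor=1.0)
    images = torch.randn(1, 3, 64, 64, device=vae.device, dtype=vae.dtype)

    # "encode" to latents (identity), then apply the usual scaling
    latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor

    # latents are just the pixels, so shapes match exactly
    assert latents.shape == images.shape

    # unscale and "decode" back to pixels (identity round trip)
    decoded = vae.decode(latents / vae.config.scaling_factor).sample
    assert torch.equal(decoded, images)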