mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Pixel shuffle adapter. Some bug fixes thrown in
@@ -13,7 +13,7 @@ SaveFormat = Literal['safetensors', 'diffusers']

if TYPE_CHECKING:
    from toolkit.guidance import GuidanceType
-   from toolkit.logging import EmptyLogger
+   from toolkit.logging_aitk import EmptyLogger
else:
    EmptyLogger = None

@@ -252,6 +252,9 @@ class AdapterConfig:
        self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
        self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False)
        self.invert_inpaint_mask_chance: float = kwargs.get('invert_inpaint_mask_chance', 0.0)

        # for subpixel adapter
        self.subpixel_downscale_factor: int = kwargs.get('subpixel_downscale_factor', 8)


class EmbeddingConfig:

@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.single_value_adapter import SingleValueAdapter
from toolkit.models.te_adapter import TEAdapter
@@ -103,6 +104,7 @@ class CustomAdapter(torch.nn.Module):
        self.single_value_adapter: SingleValueAdapter = None
        self.redux_adapter: ReduxImageEncoder = None
        self.control_lora: ControlLoraAdapter = None
        self.subpixel_adapter: SubpixelAdapter = None

        self.conditional_embeds: Optional[torch.Tensor] = None
        self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -253,6 +255,13 @@ class CustomAdapter(torch.nn.Module):
                config=self.config,
                train_config=self.train_config
            )
        elif self.adapter_type == 'subpixel':
            self.subpixel_adapter = SubpixelAdapter(
                self,
                sd=self.sd_ref(),
                config=self.config,
                train_config=self.train_config
            )
        else:
            raise ValueError(f"unknown adapter type: {self.adapter_type}")

@@ -284,7 +293,7 @@ class CustomAdapter(torch.nn.Module):
    def setup_clip(self):
        adapter_config = self.config
        sd = self.sd_ref()
-       if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
+       if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
            return
        if self.config.type == 'photo_maker':
            try:
@@ -502,6 +511,14 @@ class CustomAdapter(torch.nn.Module):
                for k2, v2 in v.items():
                    new_dict[k + '.' + k2] = v2
            self.control_lora.load_weights(new_dict, strict=strict)

        if self.adapter_type == 'subpixel':
            # state dict is separated, so recombine it
            new_dict = {}
            for k, v in state_dict.items():
                for k2, v2 in v.items():
                    new_dict[k + '.' + k2] = v2
            self.subpixel_adapter.load_weights(new_dict, strict=strict)

        pass

@@ -558,6 +575,11 @@ class CustomAdapter(torch.nn.Module):
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        elif self.adapter_type == 'subpixel':
            d = self.subpixel_adapter.get_state_dict()
            for k, v in d.items():
                state_dict[k] = v
            return state_dict
        else:
            raise NotImplementedError

@@ -702,7 +724,7 @@ class CustomAdapter(torch.nn.Module):
        prompt: Union[List[str], str],
        is_unconditional: bool = False,
    ):
-       if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
+       if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
            return prompt
        elif self.adapter_type == 'text_encoder':
            # todo allow for training
@@ -1225,6 +1247,10 @@ class CustomAdapter(torch.nn.Module):
            param_list = self.control_lora.get_params()
            for param in param_list:
                yield param
        elif self.config.type == 'subpixel':
            param_list = self.subpixel_adapter.get_params()
            for param in param_list:
                yield param
        else:
            raise NotImplementedError


@@ -381,6 +381,8 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
            sd: 'StableDiffusion' = None,
    ):
        self.dataset_config = dataset_config
        # update bucket divisibility
        self.dataset_config.bucket_tolerance = sd.get_bucket_divisibility()
        self.is_video = dataset_config.num_frames > 1
        super().__init__()
        folder_path = dataset_config.folder_path

@@ -1,5 +1,6 @@
import copy
import gc
import inspect
import json
import random
import shutil
@@ -230,6 +231,20 @@ class BaseModel:
    def is_lumina2(self):
        return self.arch == 'lumina2'

    def get_bucket_divisibility(self):
        if self.vae is None:
            return 8
        try:
            divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        except:
            # if we have a custom vae, it might not have this
            divisibility = 8

        # flux packs this again,
        if self.is_flux:
            divisibility = divisibility * 4
        return divisibility

    # these must be implemented in child classes
    def load_model(self):
        # override this in child classes
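For reference, a rough sketch of the values this yields, assuming a typical SD-style VAE config (the block_out_channels below are illustrative, not taken from this commit):

    block_out_channels = (128, 256, 512, 512)            # 4 resolution levels (assumed)
    divisibility = 2 ** (len(block_out_channels) - 1)    # -> 8, latents are 1/8 of pixel size
    divisibility_flux = divisibility * 4                 # -> 32, flux packs 2x2 latent patches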
@@ -797,13 +812,20 @@ class BaseModel:
        self.unet.to(self.device_torch)
        if self.unet.dtype != self.torch_dtype:
            self.unet = self.unet.to(dtype=self.torch_dtype)

        # check if get_noise_prediction has guidance_embedding_scale
        # if it does not, we don't pass it
        signatures = inspect.signature(self.get_noise_prediction).parameters

        if 'guidance_embedding_scale' in signatures:
            kwargs['guidance_embedding_scale'] = guidance_embedding_scale
        if 'bypass_guidance_embedding' in signatures:
            kwargs['bypass_guidance_embedding'] = bypass_guidance_embedding

        noise_pred = self.get_noise_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            text_embeddings=text_embeddings,
            guidance_embedding_scale=guidance_embedding_scale,
            bypass_guidance_embedding=bypass_guidance_embedding,
            **kwargs
        )


toolkit/models/subpixel_adapter.py (new file, 302 lines)
@@ -0,0 +1,302 @@
import inspect
import weakref
import torch
from typing import TYPE_CHECKING
from toolkit.lora_special import LoRASpecialNetwork
from diffusers import FluxTransformer2DModel
# weakref
from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter


class InOutModule(torch.nn.Module):
    def __init__(
        self,
        adapter: 'SubpixelAdapter',
        orig_layer: torch.nn.Linear,
        in_channels=64,
        out_channels=3072
    ):
        super().__init__()
        # only do the weight for the new input. We combine with the original linear layer
        self.x_embedder = torch.nn.Linear(
            in_channels,
            out_channels,
            bias=True,
        )

        self.proj_out = torch.nn.Linear(
            out_channels,
            in_channels,
            bias=True,
        )
        # make sure the weight is float32
        self.x_embedder.weight.data = self.x_embedder.weight.data.float()
        self.x_embedder.bias.data = self.x_embedder.bias.data.float()

        self.proj_out.weight.data = self.proj_out.weight.data.float()
        self.proj_out.bias.data = self.proj_out.bias.data.float()

        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)

    @classmethod
    def from_model(
        cls,
        model: FluxTransformer2DModel,
        adapter: 'SubpixelAdapter',
        num_channels: int = 768,
        downscale_factor: int = 8
    ):
        if model.__class__.__name__ == 'FluxTransformer2DModel':

            x_embedder: torch.nn.Linear = model.x_embedder
            proj_out: torch.nn.Linear = model.proj_out
            in_out_module = cls(
                adapter,
                orig_layer=x_embedder,
                in_channels=num_channels,
                out_channels=x_embedder.out_features,
            )

            # hijack the forward method
            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
            x_embedder.forward = in_out_module.in_forward
            proj_out._orig_ctrl_lora_forward = proj_out.forward
            proj_out.forward = in_out_module.out_forward

            # update the config of the transformer
            model.config.in_channels = num_channels
            model.config["in_channels"] = num_channels
            model.config.out_channels = num_channels
            model.config["out_channels"] = num_channels

            # replace the vae of the model
            sd = adapter.sd_ref()
            sd.vae = AutoencoderPixelMixer(
                in_channels=3,
                downscale_factor=downscale_factor
            )

            sd.pipeline.vae = sd.vae

            return in_out_module
        else:
            raise ValueError("Model not supported")

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def in_forward(self, x):
        if not self.is_active:
            # make sure lora is not active
            if self.adapter_ref().control_lora is not None:
                self.adapter_ref().control_lora.is_active = False
            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

        # make sure lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True

        orig_device = x.device
        orig_dtype = x.dtype

        x = x.to(self.x_embedder.weight.device, dtype=self.x_embedder.weight.dtype)

        x = self.x_embedder(x)

        x = x.to(orig_device, dtype=orig_dtype)
        return x

    def out_forward(self, x):
        if not self.is_active:
            # make sure lora is not active
            if self.adapter_ref().control_lora is not None:
                self.adapter_ref().control_lora.is_active = False
            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)

        # make sure lora is active
        if self.adapter_ref().control_lora is not None:
            self.adapter_ref().control_lora.is_active = True

        orig_device = x.device
        orig_dtype = x.dtype

        x = x.to(self.proj_out.weight.device, dtype=self.proj_out.weight.dtype)

        x = self.proj_out(x)

        x = x.to(orig_device, dtype=orig_dtype)
        return x
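    # Both hooks run their projection in the module's own (float32) dtype and device and
    # cast activations back afterwards; when the adapter is inactive they fall through to
    # the original Flux layers and switch the paired control_lora off, so the LoRA deltas
    # only apply while the subpixel path is in use.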


class SubpixelAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        config: 'AdapterConfig',
        train_config: 'TrainConfig'
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch
        self.control_lora = None

        if self.network_config is not None:

            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
            if hasattr(sd, 'target_lora_modules'):
                network_kwargs['target_lin_modules'] = sd.target_lora_modules

            if 'ignore_if_contains' not in network_kwargs:
                network_kwargs['ignore_if_contains'] = []

            # always ignore x_embedder
            network_kwargs['ignore_if_contains'].append('transformer.x_embedder')
            network_kwargs['ignore_if_contains'].append('transformer.proj_out')

            self.control_lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs
            )
            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
            self.control_lora._update_torch_multiplier()
            self.control_lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet
            )
            self.control_lora.can_merge_in = False
            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.control_lora.enable_gradient_checkpointing()

        downscale_factor = config.subpixel_downscale_factor
        if downscale_factor == 8:
            num_channels = 768
        elif downscale_factor == 16:
            num_channels = 3072
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported"
            )
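        # The mapping above appears to follow from pixel-shuffle encoding plus Flux's
        # 2x2 latent packing (an inference, not stated in the source):
        #   downscale 8:  3 * 8 * 8   = 192 latent channels, packed 2x2 -> 768
        #   downscale 16: 3 * 16 * 16 = 768 latent channels, packed 2x2 -> 3072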

        self.in_out: InOutModule = InOutModule.from_model(
            sd.unet_unwrapped,
            self,
            num_channels=num_channels,  # packed channels
            downscale_factor=downscale_factor
        )
        self.in_out.to(self.device_torch)

    def get_params(self):
        if self.control_lora is not None:
            config = {
                'text_encoder_lr': self.train_config.lr,
                'unet_lr': self.train_config.lr,
            }
            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
            if 'default_lr' in sig.parameters:
                config['default_lr'] = self.train_config.lr
            if 'learning_rate' in sig.parameters:
                config['learning_rate'] = self.train_config.lr
            params_net = self.control_lora.prepare_optimizer_params(
                **config
            )

            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is float32
        self.in_out.to(torch.float32)

        params += list(self.in_out.parameters())

        # we need to be able to yield from the list like yield from params

        return params

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        img_embedder_sd = {}
        for key, value in state_dict.items():
            if "transformer.x_embedder" in key:
                new_key = key.replace("transformer.", "")
                img_embedder_sd[new_key] = value
            elif "transformer.proj_out" in key:
                new_key = key.replace("transformer.", "")
                img_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading
        if self.control_lora is not None:
            self.control_lora.load_weights(lora_sd)
        # automatically upgrade the x embedder if more dims are added
        self.in_out.load_state_dict(img_embedder_sd, strict=False)

    def get_state_dict(self):
        if self.control_lora is not None:
            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo make sure we match loras elsewhere.
        img_embedder_sd = self.in_out.state_dict()
        for key, value in img_embedder_sd.items():
            lora_sd[f"transformer.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
@@ -318,6 +318,9 @@ class Wan21(BaseModel):

        # cache for holding noise
        self.effective_noise = None

    def get_bucket_divisibility(self):
        return 16

    # static method to get the scheduler
    @staticmethod

toolkit/pixel_shuffle_encoder.py (new file, 211 lines)
@@ -0,0 +1,211 @@
from diffusers import AutoencoderKL
from typing import Optional, Union
import torch
import torch.nn as nn
import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


class PixelMixer(nn.Module):
    def __init__(self, in_channels, downscale_factor):
        super(PixelMixer, self).__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        return torch.nn.PixelUnshuffle(self.downscale_factor)(x)

    def decode(self, x):
        return torch.nn.PixelShuffle(self.downscale_factor)(x)
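
# PixelUnshuffle(r) folds each r x r block of pixels into channels,
# (B, C, H, W) -> (B, C * r * r, H // r, W // r), so with in_channels=3 and
# downscale_factor=8 a 512x512 image becomes a (192, 64, 64) "latent";
# PixelShuffle inverts it exactly, making the round trip lossless and parameter-free.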

# for reference

# none of this matters with llvae, but we need to match the interface (latent_channels might matter)

class Config:
    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192  # usually 4
    norm_num_groups = 32
    sample_size = 512
    # scaling_factor = 1
    # shift_factor = 0
    scaling_factor = 1.8
    shift_factor = -0.123
    # VAE
    # - Mean: -0.12306906282901764
    # - Std: 0.556016206741333
    # Normalization parameters:
    # - Shift factor: -0.12306906282901764
    # - Scaling factor: 1.7985087266803625

    def __getitem__(cls, x):
        return getattr(cls, x)

class AutoencoderPixelMixer(nn.Module):

    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()

        if downscale_factor == 8:
            # we go by len of block out channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192

        elif downscale_factor == 16:
            # we go by len of block out channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic torch's .to()
    def to(self, *args, **kwargs):
        # pull out dtype and device if they exist
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        pass

    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:

        h = self.mixer.encode(x)

        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (h,)

        class FakeDist:
            def __init__(self, x):
                self._sample = x

            def sample(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:

        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)


# test it
if __name__ == '__main__':
    import os
    from PIL import Image
    import torchvision.transforms as transforms
    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    output_path = os.path.join(user_path, "Pictures/test/test.jpg")
    img = Image.open(input_path)
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))
    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))
    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    print("roundtrip_loss: ", mse.item())
    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
    out_img.save(output_path)
@@ -249,6 +249,17 @@ class StableDiffusion:
    @property
    def unet_unwrapped(self):
        return unwrap_model(self.unet)

    def get_bucket_divisibility(self):
        if self.vae is None:
            return 8
        divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)

        # flux packs this again,
        if self.is_flux:
            divisibility = divisibility * 4
        return divisibility


    def load_model(self):
        if self.is_loaded:
@@ -1721,6 +1732,7 @@ class StableDiffusion:
        pixel_width=None,
        batch_size=1,
        noise_offset=0.0,
        num_channels=None,
    ):
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        if height is None and pixel_height is None:
@@ -1732,10 +1744,11 @@ class StableDiffusion:
        if width is None:
            width = pixel_width // VAE_SCALE_FACTOR

-       num_channels = self.unet_unwrapped.config['in_channels']
-       if self.is_flux:
-           # has 64 channels in for some reason
-           num_channels = 16
+       if num_channels is None:
+           num_channels = self.unet_unwrapped.config['in_channels']
+           if self.is_flux:
+               # it gets packed, unpack it
+               num_channels = num_channels // 4
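            # (Deriving the count from the transformer config and unpacking by 4, rather
            # than hard-coding 16, presumably keeps this working when the subpixel adapter
            # widens in_channels to 768 or 3072, i.e. 192 or 768 latent channels.)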
        noise = torch.randn(
            (
                batch_size,