diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 6cbaa542..e37edd3a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -59,7 +59,7 @@ from tqdm import tqdm
 from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
     GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
     DecoratorConfig
-from toolkit.logging import create_logger
+from toolkit.logging_aitk import create_logger
 from diffusers import FluxTransformer2DModel
 from toolkit.accelerator import get_accelerator
 from toolkit.print import print_acc
@@ -578,7 +578,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 direct_save = True
             if self.adapter_config.type == 'redux':
                 direct_save = True
-            if self.adapter_config.type == 'control_lora':
+            if self.adapter_config.type in ['control_lora', 'subpixel']:
                 direct_save = True
             save_ip_adapter_from_diffusers(
                 state_dict,
@@ -918,6 +918,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         noise = self.sd.get_latent_noise(
             height=latents.shape[2],
             width=latents.shape[3],
+            num_channels=latents.shape[1],
             batch_size=batch_size,
             noise_offset=self.train_config.noise_offset,
         ).to(self.device_torch, dtype=dtype)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ed0adead..f76def92 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -13,7 +13,7 @@ SaveFormat = Literal['safetensors', 'diffusers']
 
 if TYPE_CHECKING:
     from toolkit.guidance import GuidanceType
-    from toolkit.logging import EmptyLogger
+    from toolkit.logging_aitk import EmptyLogger
 else:
     EmptyLogger = None
 
@@ -252,6 +252,9 @@ class AdapterConfig:
         self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
         self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False)
         self.invert_inpaint_mask_chance: float = kwargs.get('invert_inpaint_mask_chance', 0.0)
+
+        # for subpixel adapter
+        self.subpixel_downscale_factor: int = kwargs.get('subpixel_downscale_factor', 8)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index f88bcdaa..6f8be862 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.control_lora_adapter import ControlLoraAdapter
+from toolkit.models.subpixel_adapter import SubpixelAdapter
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
@@ -103,6 +104,7 @@ class CustomAdapter(torch.nn.Module):
         self.single_value_adapter: SingleValueAdapter = None
         self.redux_adapter: ReduxImageEncoder = None
         self.control_lora: ControlLoraAdapter = None
+        self.subpixel_adapter: SubpixelAdapter = None
 
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -253,6 +255,13 @@ class CustomAdapter(torch.nn.Module):
                 config=self.config,
                 train_config=self.train_config
             )
+        elif self.adapter_type == 'subpixel':
+            self.subpixel_adapter = SubpixelAdapter(
+                self,
+                sd=self.sd_ref(),
+                config=self.config,
+                train_config=self.train_config
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
 
@@ -284,7 +293,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
             return
         if self.config.type == 'photo_maker':
             try:
@@ -502,6 +511,14 @@ class CustomAdapter(torch.nn.Module):
                 for k2, v2 in v.items():
                     new_dict[k + '.' + k2] = v2
                 self.control_lora.load_weights(new_dict, strict=strict)
+
+        if self.adapter_type == 'subpixel':
+            # state dict is separated, so recombine it
+            new_dict = {}
+            for k, v in state_dict.items():
+                for k2, v2 in v.items():
+                    new_dict[k + '.' + k2] = v2
+            self.subpixel_adapter.load_weights(new_dict, strict=strict)
 
         pass
 
@@ -558,6 +575,11 @@ class CustomAdapter(torch.nn.Module):
             for k, v in d.items():
                 state_dict[k] = v
             return state_dict
+        elif self.adapter_type == 'subpixel':
+            d = self.subpixel_adapter.get_state_dict()
+            for k, v in d.items():
+                state_dict[k] = v
+            return state_dict
         else:
             raise NotImplementedError
 
@@ -702,7 +724,7 @@ class CustomAdapter(torch.nn.Module):
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
+        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
             return prompt
         elif self.adapter_type == 'text_encoder':
             # todo allow for training
@@ -1225,6 +1247,10 @@ class CustomAdapter(torch.nn.Module):
             param_list = self.control_lora.get_params()
             for param in param_list:
                 yield param
+        elif self.config.type == 'subpixel':
+            param_list = self.subpixel_adapter.get_params()
+            for param in param_list:
+                yield param
         else:
             raise NotImplementedError
 
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index d3845d2a..e52f86fe 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -381,6 +381,8 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
             sd: 'StableDiffusion' = None,
     ):
         self.dataset_config = dataset_config
+        # update bucket divisibility
+        self.dataset_config.bucket_tolerance = sd.get_bucket_divisibility()
         self.is_video = dataset_config.num_frames > 1
         super().__init__()
         folder_path = dataset_config.folder_path
diff --git a/toolkit/logging.py b/toolkit/logging_aitk.py
similarity index 100%
rename from toolkit/logging.py
rename to toolkit/logging_aitk.py
diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 43ecb1be..88284e54 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -1,5 +1,6 @@
 import copy
 import gc
+import inspect
 import json
 import random
 import shutil
@@ -230,6 +231,20 @@ class BaseModel:
     def is_lumina2(self):
         return self.arch == 'lumina2'
 
+    def get_bucket_divisibility(self):
+        if self.vae is None:
+            return 8
+        try:
+            divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+        except Exception:
+            # if we have a custom vae, it might not have this
+            divisibility = 8
+
+        # flux packs the latents again
+        if self.is_flux:
+            divisibility = divisibility * 4
+        return divisibility
+
     # these must be implemented in child classes
     def load_model(self):
         # override this in child classes
@@ -797,13 +812,20 @@ class BaseModel:
             self.unet.to(self.device_torch)
         if self.unet.dtype != self.torch_dtype:
             self.unet = self.unet.to(dtype=self.torch_dtype)
+
+        # check if get_noise_prediction accepts guidance_embedding_scale;
+        # if it does not, don't pass it
+        signatures = inspect.signature(self.get_noise_prediction).parameters
+
+        if 'guidance_embedding_scale' in signatures:
+            kwargs['guidance_embedding_scale'] = guidance_embedding_scale
+        if 'bypass_guidance_embedding' in signatures:
+            kwargs['bypass_guidance_embedding'] = bypass_guidance_embedding
 
         noise_pred = self.get_noise_prediction(
             latent_model_input=latent_model_input,
             timestep=timestep,
             text_embeddings=text_embeddings,
-            guidance_embedding_scale=guidance_embedding_scale,
-            bypass_guidance_embedding=bypass_guidance_embedding,
             **kwargs
         )
 
diff --git a/toolkit/models/subpixel_adapter.py b/toolkit/models/subpixel_adapter.py
new file mode 100644
index 00000000..5429265d
--- /dev/null
+++ b/toolkit/models/subpixel_adapter.py
@@ -0,0 +1,302 @@
+import inspect
+import weakref
+import torch
+from typing import TYPE_CHECKING
+from toolkit.lora_special import LoRASpecialNetwork
+from diffusers import FluxTransformer2DModel
+# weakref
+from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer
+
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
+    from toolkit.custom_adapter import CustomAdapter
+
+
+class InOutModule(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'SubpixelAdapter',
+        orig_layer: torch.nn.Linear,
+        in_channels=64,
+        out_channels=3072
+    ):
+        super().__init__()
+        # only do the weight for the new input. We combine with the original linear layer
+        self.x_embedder = torch.nn.Linear(
+            in_channels,
+            out_channels,
+            bias=True,
+        )
+
+        self.proj_out = torch.nn.Linear(
+            out_channels,
+            in_channels,
+            bias=True,
+        )
+        # make sure the weight is float32
+        self.x_embedder.weight.data = self.x_embedder.weight.data.float()
+        self.x_embedder.bias.data = self.x_embedder.bias.data.float()
+
+        self.proj_out.weight.data = self.proj_out.weight.data.float()
+        self.proj_out.bias.data = self.proj_out.bias.data.float()
+
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer)
+
+    @classmethod
+    def from_model(
+        cls,
+        model: FluxTransformer2DModel,
+        adapter: 'SubpixelAdapter',
+        num_channels: int = 768,
+        downscale_factor: int = 8
+    ):
+        if model.__class__.__name__ == 'FluxTransformer2DModel':
+
+            x_embedder: torch.nn.Linear = model.x_embedder
+            proj_out: torch.nn.Linear = model.proj_out
+            in_out_module = cls(
+                adapter,
+                orig_layer=x_embedder,
+                in_channels=num_channels,
+                out_channels=x_embedder.out_features,
+            )
+
+            # hijack the forward method
+            x_embedder._orig_ctrl_lora_forward = x_embedder.forward
+            x_embedder.forward = in_out_module.in_forward
+            proj_out._orig_ctrl_lora_forward = proj_out.forward
+            proj_out.forward = in_out_module.out_forward
+
+            # update the config of the transformer
+            model.config.in_channels = num_channels
+            model.config["in_channels"] = num_channels
+            model.config.out_channels = num_channels
+            model.config["out_channels"] = num_channels
+
+            # replace the vae of the model
+            sd = adapter.sd_ref()
+            sd.vae = AutoencoderPixelMixer(
+                in_channels=3,
+                downscale_factor=downscale_factor
+            )
+
+            sd.pipeline.vae = sd.vae
+
+            return in_out_module
+        else:
+            raise ValueError("Model not supported")
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+
+    def in_forward(self, x):
+        if not self.is_active:
+            # make sure lora is not active
+            if self.adapter_ref().control_lora is not None:
+                self.adapter_ref().control_lora.is_active = False
+            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
+
+        # make sure lora is active
+        if self.adapter_ref().control_lora is not None:
+            self.adapter_ref().control_lora.is_active = True
+
+        orig_device = x.device
+        orig_dtype = x.dtype
+
+        x = x.to(self.x_embedder.weight.device, dtype=self.x_embedder.weight.dtype)
+
+        x = self.x_embedder(x)
+
+        x = x.to(orig_device, dtype=orig_dtype)
+        return x
+
+    def out_forward(self, x):
+        if not self.is_active:
+            # make sure lora is not active
+            if self.adapter_ref().control_lora is not None:
+                self.adapter_ref().control_lora.is_active = False
+            return self.orig_layer_ref()._orig_ctrl_lora_forward(x)
+
+        # make sure lora is active
+        if self.adapter_ref().control_lora is not None:
+            self.adapter_ref().control_lora.is_active = True
+
+        orig_device = x.device
+        orig_dtype = x.dtype
+
+        x = x.to(self.proj_out.weight.device, dtype=self.proj_out.weight.dtype)
+
+        x = self.proj_out(x)
+
+        x = x.to(orig_device, dtype=orig_dtype)
+        return x
+
+
+class SubpixelAdapter(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'CustomAdapter',
+        sd: 'StableDiffusion',
+        config: 'AdapterConfig',
+        train_config: 'TrainConfig'
+    ):
+        super().__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.sd_ref = weakref.ref(sd)
+        self.model_config: ModelConfig = sd.model_config
+        self.network_config = config.lora_config
+        self.train_config = train_config
+        self.device_torch = sd.device_torch
+        self.control_lora = None
+
+        if self.network_config is not None:
+
+            network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
+            if hasattr(sd, 'target_lora_modules'):
+                network_kwargs['target_lin_modules'] = sd.target_lora_modules
+
+            if 'ignore_if_contains' not in network_kwargs:
+                network_kwargs['ignore_if_contains'] = []
+
+            # always ignore x_embedder and proj_out
+            network_kwargs['ignore_if_contains'].append('transformer.x_embedder')
+            network_kwargs['ignore_if_contains'].append('transformer.proj_out')
+
+            self.control_lora = LoRASpecialNetwork(
+                text_encoder=sd.text_encoder,
+                unet=sd.unet,
+                lora_dim=self.network_config.linear,
+                multiplier=1.0,
+                alpha=self.network_config.linear_alpha,
+                train_unet=self.train_config.train_unet,
+                train_text_encoder=self.train_config.train_text_encoder,
+                conv_lora_dim=self.network_config.conv,
+                conv_alpha=self.network_config.conv_alpha,
+                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
+                is_v2=self.model_config.is_v2,
+                is_v3=self.model_config.is_v3,
+                is_pixart=self.model_config.is_pixart,
+                is_auraflow=self.model_config.is_auraflow,
+                is_flux=self.model_config.is_flux,
+                is_lumina2=self.model_config.is_lumina2,
+                is_ssd=self.model_config.is_ssd,
+                is_vega=self.model_config.is_vega,
+                dropout=self.network_config.dropout,
+                use_text_encoder_1=self.model_config.use_text_encoder_1,
+                use_text_encoder_2=self.model_config.use_text_encoder_2,
+                use_bias=False,
+                is_lorm=False,
+                network_config=self.network_config,
+                network_type=self.network_config.type,
+                transformer_only=self.network_config.transformer_only,
+                is_transformer=sd.is_transformer,
+                base_model=sd,
+                **network_kwargs
+            )
+            self.control_lora.force_to(self.device_torch, dtype=torch.float32)
+            self.control_lora._update_torch_multiplier()
+            self.control_lora.apply_to(
+                sd.text_encoder,
+                sd.unet,
+                self.train_config.train_text_encoder,
+                self.train_config.train_unet
+            )
+            self.control_lora.can_merge_in = False
+            self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet)
+            if self.train_config.gradient_checkpointing:
+                self.control_lora.enable_gradient_checkpointing()
+
+        downscale_factor = config.subpixel_downscale_factor
+        if downscale_factor == 8:
+            num_channels = 768
+        elif downscale_factor == 16:
+            num_channels = 3072
+        else:
+            raise ValueError(
+                f"downscale_factor {downscale_factor} not supported"
+            )
+
+        self.in_out: InOutModule = InOutModule.from_model(
+            sd.unet_unwrapped,
+            self,
+            num_channels=num_channels,  # packed channels
+            downscale_factor=downscale_factor
+        )
+        self.in_out.to(self.device_torch)
+
+    def get_params(self):
+        if self.control_lora is not None:
+            config = {
+                'text_encoder_lr': self.train_config.lr,
+                'unet_lr': self.train_config.lr,
+            }
+            sig = inspect.signature(self.control_lora.prepare_optimizer_params)
+            if 'default_lr' in sig.parameters:
+                config['default_lr'] = self.train_config.lr
+            if 'learning_rate' in sig.parameters:
+                config['learning_rate'] = self.train_config.lr
+            params_net = self.control_lora.prepare_optimizer_params(
+                **config
+            )
+
+            # we want only tensors here
+            params = []
+            for p in params_net:
+                if isinstance(p, dict):
+                    params += p["params"]
+                elif isinstance(p, torch.Tensor):
+                    params.append(p)
+                elif isinstance(p, list):
+                    params += p
+        else:
+            params = []
+
+        # make sure the embedder is float32
+        self.in_out.to(torch.float32)
+
+        params += list(self.in_out.parameters())
+
+        # callers iterate over this (e.g. yield from params), so keep it a flat list of tensors
+        return params
+
+    def load_weights(self, state_dict, strict=True):
+        lora_sd = {}
+        img_embedder_sd = {}
+        for key, value in state_dict.items():
+            if "transformer.x_embedder" in key:
+                new_key = key.replace("transformer.", "")
+                img_embedder_sd[new_key] = value
+            elif "transformer.proj_out" in key:
+                new_key = key.replace("transformer.", "")
+                img_embedder_sd[new_key] = value
+            else:
+                lora_sd[key] = value
+
+        # todo process state dict before loading
+        if self.control_lora is not None:
+            self.control_lora.load_weights(lora_sd)
+        # automatically upgrade the x_embedder if more dims are added
+        self.in_out.load_state_dict(img_embedder_sd, strict=False)
+
+    def get_state_dict(self):
+        if self.control_lora is not None:
+            lora_sd = self.control_lora.get_state_dict(dtype=torch.float32)
+        else:
+            lora_sd = {}
+        # todo make sure we match lora key naming elsewhere
+        img_embedder_sd = self.in_out.state_dict()
+        for key, value in img_embedder_sd.items():
+            lora_sd[f"transformer.{key}"] = value
+        return lora_sd
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index e1636b9d..48992cfb 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -318,6 +318,9 @@ class Wan21(BaseModel):
 
         # cache for holding noise
         self.effective_noise = None
+
+    def get_bucket_divisibility(self):
+        return 16
 
     # static method to get the scheduler
     @staticmethod
diff --git a/toolkit/pixel_shuffle_encoder.py b/toolkit/pixel_shuffle_encoder.py
new file mode 100644
index 00000000..8848e405
--- /dev/null
+++ b/toolkit/pixel_shuffle_encoder.py
@@ -0,0 +1,211 @@
+from diffusers import AutoencoderKL
+from typing import Optional, Union
+import torch
+import torch.nn as nn
+import numpy as np
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+
+class PixelMixer(nn.Module):
+    def __init__(self, in_channels, downscale_factor):
+        super(PixelMixer, self).__init__()
+        self.downscale_factor = downscale_factor
+        self.in_channels = in_channels
+
+    def forward(self, x):
+        latent = self.encode(x)
+        out = self.decode(latent)
+        return out
+
+    def encode(self, x):
+        return torch.nn.PixelUnshuffle(self.downscale_factor)(x)
+
+    def decode(self, x):
+        return torch.nn.PixelShuffle(self.downscale_factor)(x)
+
+
+# for reference
+
+# none of this matters with llvae, but we need to match the interface (latent_channels might matter)
+
+class Config:
+    in_channels = 3
+    out_channels = 3
+    down_block_types = ('1', '1', '1', '1')
+    up_block_types = ('1', '1', '1', '1')
+    block_out_channels = (1, 1, 1, 1)
+    latent_channels = 192  # usually 4
+    norm_num_groups = 32
+    sample_size = 512
+    # scaling_factor = 1
+    # shift_factor = 0
+    scaling_factor = 1.8
+    shift_factor = -0.123
+    # VAE
+    # - Mean: -0.12306906282901764
+    # - Std: 0.556016206741333
+    # Normalization parameters:
+    # - Shift factor: -0.12306906282901764
+    # - Scaling factor: 1.7985087266803625
+
+    def __getitem__(cls, x):
+        return getattr(cls, x)
+
+
+class AutoencoderPixelMixer(nn.Module):
+
+    def __init__(self, in_channels=3, downscale_factor=8):
+        super().__init__()
+        self.mixer = PixelMixer(in_channels, downscale_factor)
+        self._dtype = torch.float32
+        self._device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
+        self.config = Config()
+
+        if downscale_factor == 8:
+            # we go by len of block out channels in code, so simulate it
+            self.config.block_out_channels = (1, 1, 1, 1)
+            self.config.latent_channels = 192
+        elif downscale_factor == 16:
+            # we go by len of block out channels in code, so simulate it
+            self.config.block_out_channels = (1, 1, 1, 1, 1)
+            self.config.latent_channels = 768
+        else:
+            raise ValueError(
+                f"downscale_factor {downscale_factor} not supported")
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        self._dtype = value
+
+    @property
+    def device(self):
+        return self._device
+
+    @device.setter
+    def device(self, value):
+        self._device = value
+
+    # mimic torch's .to() so dtype/device tracking keeps working
+    def to(self, *args, **kwargs):
+        # pull out dtype and device if they exist
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        if 'device' in kwargs:
+            self._device = kwargs['device']
+        return super().to(*args, **kwargs)
+
+    def enable_xformers_memory_efficient_attention(self):
+        pass
+
+    # @apply_forward_hook
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+
+        h = self.mixer.encode(x)
+
+        # moments = self.quant_conv(h)
+        # posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (h,)
+
+        class FakeDist:
+            def __init__(self, x):
+                self._sample = x
+
+            def sample(self, generator=None):
+                return self._sample
+
+            def mode(self):
+                return self._sample
+
+        return AutoencoderKLOutput(latent_dist=FakeDist(h))
+
+    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        dec = self.mixer.decode(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    # @apply_forward_hook
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        pass
+
+    def enable_tiling(self, use_tiling: bool = True):
+        pass
+
+    def disable_tiling(self):
+        pass
+
+    def enable_slicing(self):
+        pass
+
+    def disable_slicing(self):
+        pass
+
+    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
+        pass
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
+
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+
+# test it
+if __name__ == '__main__':
+    import os
+    from PIL import Image
+    import torchvision.transforms as transforms
+    user_path = os.path.expanduser('~')
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    dtype = torch.float32
+
+    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
+    output_path = os.path.join(user_path, "Pictures/test/test_out.jpg")
+    img = Image.open(input_path)
+    img_tensor = transforms.ToTensor()(img)
+    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
+    print("input_shape: ", list(img_tensor.shape))
+    vae = PixelMixer(in_channels=3, downscale_factor=8)
+    latent = vae.encode(img_tensor)
+    print("latent_shape: ", list(latent.shape))
+    out_tensor = vae.decode(latent)
+    print("out_shape: ", list(out_tensor.shape))
+
+    mse_loss = nn.MSELoss()
+    mse = mse_loss(img_tensor, out_tensor)
+    print("roundtrip_loss: ", mse.item())
+    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
+    out_img.save(output_path)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 53d41637..8039d060 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -249,6 +249,17 @@ class StableDiffusion:
     @property
     def unet_unwrapped(self):
         return unwrap_model(self.unet)
+
+    def get_bucket_divisibility(self):
+        if self.vae is None:
+            return 8
+        divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+
+        # flux packs the latents again
+        if self.is_flux:
+            divisibility = divisibility * 4
+        return divisibility
+
 
     def load_model(self):
         if self.is_loaded:
@@ -1721,6 +1732,7 @@ class StableDiffusion:
             pixel_width=None,
             batch_size=1,
             noise_offset=0.0,
+            num_channels=None,
     ):
         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
         if height is None and pixel_height is None:
@@ -1732,10 +1744,11 @@ class StableDiffusion:
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR
 
-        num_channels = self.unet_unwrapped.config['in_channels']
-        if self.is_flux:
-            # has 64 channels in for some reason
-            num_channels = 16
+        if num_channels is None:
+            num_channels = self.unet_unwrapped.config['in_channels']
+            if self.is_flux:
+                # it gets packed, unpack it
+                num_channels = num_channels // 4
         noise = torch.randn(
             (
                 batch_size,
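
Note (not part of the diff): a minimal, standalone sketch of the shape math behind the new pixel-shuffle encoder and the 768 / 3072 channel counts hard-coded in SubpixelAdapter. It assumes only PyTorch; the sizes below are illustrative and not taken from the toolkit.

import torch

for downscale_factor, packed_in_channels in [(8, 768), (16, 3072)]:
    x = torch.randn(1, 3, 512, 512)  # a batch of RGB images
    # PixelMixer.encode is just PixelUnshuffle: 3 * f * f latent channels,
    # spatial size divided by f (192 channels for f=8, 768 for f=16)
    latent = torch.nn.PixelUnshuffle(downscale_factor)(x)
    f = downscale_factor
    assert latent.shape == (1, 3 * f * f, 512 // f, 512 // f)
    # Flux packs latents 2x2 before x_embedder, quadrupling the channel count;
    # that packed count is what InOutModule uses for the new x_embedder / proj_out
    assert 3 * f * f * 4 == packed_in_channels
    # the shuffle round trip is exact, unlike a learned VAE
    restored = torch.nn.PixelShuffle(downscale_factor)(latent)
    assert torch.equal(restored, x)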