diff --git a/extensions_built_in/advanced_generator/Img2ImgGenerator.py b/extensions_built_in/advanced_generator/Img2ImgGenerator.py
index 2a5cfe3f..58713bd0 100644
--- a/extensions_built_in/advanced_generator/Img2ImgGenerator.py
+++ b/extensions_built_in/advanced_generator/Img2ImgGenerator.py
@@ -1,3 +1,4 @@
+import math
 import os
 import random
 from collections import OrderedDict
@@ -6,8 +7,9 @@ from typing import List
 import numpy as np
 from PIL import Image
 from diffusers import T2IAdapter
+from diffusers.utils.torch_utils import randn_tensor
 from torch.utils.data import DataLoader
-from diffusers import StableDiffusionXLImg2ImgPipeline
+from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
 from tqdm import tqdm
 
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -21,6 +23,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
 from toolkit.train_tools import get_torch_dtype
 from controlnet_aux.midas import MidasDetector
 from diffusers.utils import load_image
+from torchvision.transforms import ToTensor
 
 
 def flush():
@@ -28,6 +31,9 @@ def flush():
     gc.collect()
 
 
+
+
+
 class GenerateConfig:
 
     def __init__(self, **kwargs):
@@ -103,7 +109,6 @@ class Img2ImgGenerator(BaseExtensionProcess):
         self.sd.load_model()
         device = torch.device(self.device)
 
-
         if self.model_config.is_xl:
             pipe = StableDiffusionXLImg2ImgPipeline(
                 vae=self.sd.vae,
@@ -114,6 +119,8 @@ class Img2ImgGenerator(BaseExtensionProcess):
                 tokenizer_2=self.sd.tokenizer[1],
                 scheduler=get_sampler(self.generate_config.sampler),
             ).to(device, dtype=self.torch_dtype)
+        elif self.model_config.is_pixart:
+            pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
         else:
             raise NotImplementedError("Only XL models are supported")
         pipe.set_progress_bar_config(disable=True)
@@ -130,6 +137,9 @@ class Img2ImgGenerator(BaseExtensionProcess):
 
         for i, batch in enumerate(self.data_loader):
             batch: DataLoaderBatchDTO = batch
+            gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
+            generator = torch.manual_seed(gen_seed)
+
             file_item: FileItemDTO = batch.file_items[0]
             img_path = file_item.path
             img_filename = os.path.basename(img_path)
@@ -152,18 +162,76 @@ class Img2ImgGenerator(BaseExtensionProcess):
             img: torch.Tensor = batch.tensor.clone()
             image = self.to_pil(img)
-            # image.save(output_depth_path)
 
-            pipe: StableDiffusionXLImg2ImgPipeline = pipe
+            if self.model_config.is_pixart:
+                pipe: PixArtSigmaPipeline = pipe
 
-            gen_images = pipe.__call__(
-                prompt=caption,
-                negative_prompt=self.generate_config.neg,
-                image=image,
-                num_inference_steps=self.generate_config.sample_steps,
-                guidance_scale=self.generate_config.guidance_scale,
-                strength=self.generate_config.denoise_strength,
-            ).images[0]
+                # Encode the full image once
+                encoded_image = pipe.vae.encode(
+                    pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
+                if hasattr(encoded_image, "latent_dist"):
+                    latents = encoded_image.latent_dist.sample(generator)
+                elif hasattr(encoded_image, "latents"):
+                    latents = encoded_image.latents
+                else:
+                    raise AttributeError("Could not access latents of provided encoder_output")
+                latents = pipe.vae.config.scaling_factor * latents
+
+                # latents = self.sd.encode_images(img)
+
+                # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
+                # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
+                # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
+                # timestep = timestep.to(device, dtype=torch.int32)
+                # latent = latent.to(device, dtype=self.torch_dtype)
+                # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
+                # latent = self.sd.add_noise(latent, noise, timestep)
+                # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
+                batch_size = 1
+                num_images_per_prompt = 1
+
+                shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
+                         image.width // pipe.vae_scale_factor)
+                noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
+
+                # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
+                num_inference_steps = self.generate_config.sample_steps
+                strength = self.generate_config.denoise_strength
+                # Get timesteps
+                init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+                t_start = max(num_inference_steps - init_timestep, 0)
+                pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
+                timesteps = pipe.scheduler.timesteps[t_start:]
+                timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+                latents = pipe.scheduler.add_noise(latents, noise, timestep)
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    latents=latents,
+                    timesteps=timesteps,
+                    width=image.width,
+                    height=image.height,
+                    num_inference_steps=num_inference_steps,
+                    num_images_per_prompt=num_images_per_prompt,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    # strength=self.generate_config.denoise_strength,
+                    use_resolution_binning=False,
+                    output_type="np"
+                ).images[0]
+                gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
+                gen_images = Image.fromarray(gen_images)
+            else:
+                pipe: StableDiffusionXLImg2ImgPipeline = pipe
+
+                gen_images = pipe.__call__(
+                    prompt=caption,
+                    negative_prompt=self.generate_config.neg,
+                    image=image,
+                    num_inference_steps=self.generate_config.sample_steps,
+                    guidance_scale=self.generate_config.guidance_scale,
+                    strength=self.generate_config.denoise_strength,
+                ).images[0]
 
             os.makedirs(os.path.dirname(output_path), exist_ok=True)
             gen_images.save(output_path)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 0cfe9608..e0a40ef4 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1331,6 +1331,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 is_lorm=is_lorm,
                 network_config=self.network_config,
                 network_type=self.network_config.type,
+                transformer_only=self.network_config.transformer_only,
                 **network_kwargs
             )
 
diff --git a/run.py b/run.py
index 6f133081..068abe9a 100644
--- a/run.py
+++ b/run.py
@@ -1,5 +1,5 @@
 import os
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import sys
 from typing import Union, OrderedDict
 from dotenv import load_dotenv
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index aec99040..835c9eb9 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -51,6 +51,9 @@ resolutions_1024: List[BucketResolution] = [
     {"width": 512, "height": 1920},
     {"width": 512, "height": 1984},
     {"width": 512, "height": 2048},
+    # extra wides
+    {"width": 8192, "height": 128},
+    {"width": 128, "height": 8192},
 ]
 
 # Even numbers so they can be patched easier
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ca7ee36e..f6324364 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -128,6 +128,8 @@ class NetworkConfig:
         if self.lorm_config.do_conv:
             self.conv = 4
 
+        self.transformer_only = kwargs.get('transformer_only', False)
+
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
 
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 0e46e4ff..9cb377fc 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -169,6 +169,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
             network_type: str = "lora",
             full_train_in_out: bool = False,
+            transformer_only: bool = False,
             **kwargs
     ) -> None:
         """
@@ -193,6 +194,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if ignore_if_contains is None:
             ignore_if_contains = []
         self.ignore_if_contains = ignore_if_contains
+        self.transformer_only = transformer_only
 
         self.only_if_contains: Union[List, None] = only_if_contains
 
@@ -271,6 +273,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     is_conv2d = child_module.__class__.__name__ in CONV_MODULES
                     is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
 
+
+                    lora_name = [prefix, name, child_name]
+                    # filter out blank segments
+                    lora_name = [x for x in lora_name if x and x != ""]
+                    lora_name = ".".join(lora_name)
+                    # if a segment has no name, the join would leave two dots
+                    lora_name = lora_name.replace("..", ".")
+                    lora_name = lora_name.replace(".", "_")
+
                     skip = False
                     if any([word in child_name for word in self.ignore_if_contains]):
                         skip = True
@@ -279,9 +290,11 @@
                         if count_parameters(child_module) < parameter_threshold:
                             skip = True
 
+                    if self.transformer_only and self.is_pixart and is_unet:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
+
                     if (is_linear or is_conv2d) and not skip:
-                        lora_name = prefix + "." + name + "." + child_name
-                        lora_name = lora_name.replace(".", "_")
 
                         if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
                             continue
@@ -356,8 +369,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             index = None
             print(f"create LoRA for Text Encoder:")
 
-        text_encoder_loras, skipped = create_modules(False, index, text_encoder,
-                                                     LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+
+        if self.is_pixart:
+            replace_modules = ["T5EncoderModel"]
+
+        text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
         self.text_encoder_loras.extend(text_encoder_loras)
         skipped_te += skipped
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 01ec66ea..2df15b82 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -516,6 +516,9 @@ class ToolkitNetworkMixin:
         load_sd = OrderedDict()
         for key, value in weights_sd.items():
             load_key = keymap[key] if key in keymap else key
+            # replace old double __ with single _
+            if self.is_pixart:
+                load_key = load_key.replace('__', '_')
             load_sd[load_key] = value
 
         # extract extra items from state dict
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 143f5ca4..ea81ccdb 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -169,15 +169,6 @@ class StableDiffusion:
         if self.is_loaded:
             return
         dtype = get_torch_dtype(self.dtype)
-        # sch = KDPM2DiscreteScheduler
-        if self.noise_scheduler is None:
-            scheduler = get_sampler(
-                'ddpm', {
-                    "prediction_type": self.prediction_type,
-                },
-                'sd' if not self.is_pixart else 'pixart'
-            )
-            self.noise_scheduler = scheduler
 
         # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
         # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
@@ -190,9 +181,10 @@ class StableDiffusion:
                 from toolkit.civitai import get_model_path_from_url
                 model_path = get_model_path_from_url(self.model_config.name_or_path)
 
-        load_args = {
-            'scheduler': self.noise_scheduler,
-        }
+        load_args = {}
+        if self.noise_scheduler:
+            load_args['scheduler'] = self.noise_scheduler
+
         if self.model_config.vae_path is not None:
            load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
         if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
@@ -290,6 +282,7 @@ class StableDiffusion:
                 device=self.device_torch,
                 torch_dtype=self.torch_dtype,
                 text_encoder_3=text_encoder3,
+                **load_args
             )
             flush()
@@ -387,6 +380,8 @@ class StableDiffusion:
                 tokenizer = pipe.tokenizer
 
             pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+            if self.noise_scheduler is None:
+                self.noise_scheduler = pipe.scheduler
 
         elif self.model_config.is_auraflow:
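
Note on the PixArt img2img branch in Img2ImgGenerator.py above: it follows the usual diffusers img2img recipe of encoding the source image to latents, keeping only the last strength * num_inference_steps scheduler timesteps, noising the latents up to the first of those timesteps, and passing both latents and timesteps into the pipeline call. The sketch below isolates just that timestep arithmetic; the DDPMScheduler and the literal values are illustrative stand-ins, not necessarily what the extension loads.

# Standalone sketch of the strength -> timesteps mapping used in the PixArt branch.
# DDPMScheduler and the numbers here are assumptions for illustration only.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

num_inference_steps = 30
strength = 0.5  # denoise_strength: 0.0 keeps the input image, 1.0 is full text-to-image

# Run only the tail of the schedule, exactly as in the diff.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
scheduler.set_timesteps(num_inference_steps, device="cpu")
timesteps = scheduler.timesteps[t_start:]
first_timestep = timesteps[:1]  # latents are noised to this point before denoising starts

latents = torch.randn(1, 4, 64, 64)   # stand-in for the VAE-encoded, scaled image latents
noise = torch.randn_like(latents)
noisy_latents = scheduler.add_noise(latents, noise, first_timestep)

print(len(timesteps), int(first_timestep))  # e.g. 15 remaining steps, starting mid-schedule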
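
The lora_special.py naming change, the transformer_only filter, and the "__" remap in network_mixins.py fit together: name segments are filtered before joining, so an empty segment no longer produces a double underscore in LoRA key names, and older PixArt checkpoints that were saved with "__" keys are normalized at load time. A rough illustration with made-up module names follows; build_lora_name and keep_module are hypothetical helpers, only the string handling mirrors the diff.

# Rough illustration of the new LoRA naming plus the transformer-only filter.
# build_lora_name / keep_module are hypothetical; module names are made up.
def build_lora_name(prefix: str, name: str, child_name: str) -> str:
    parts = [p for p in (prefix, name, child_name) if p]  # drop blank segments
    return ".".join(parts).replace(".", "_")              # single underscores, never "__"

def keep_module(lora_name: str, transformer_only: bool) -> bool:
    # With transformer_only set (PixArt), only transformer_blocks modules get a LoRA.
    return (not transformer_only) or ("transformer_blocks" in lora_name)

for prefix, name, child in [
    ("lora_transformer", "transformer_blocks.0.attn1", "to_q"),
    ("lora_transformer", "", "pos_embed"),  # empty middle segment no longer yields ".."
]:
    lora_name = build_lora_name(prefix, name, child)
    print(lora_name, keep_module(lora_name, transformer_only=True))

# Loading an older checkpoint whose keys still contain "__", as network_mixins.py now does:
old_key = "lora_transformer__pos_embed.lora_down.weight"
print(old_key.replace("__", "_"))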