diff --git a/extensions_built_in/flex2/__init__.py b/extensions_built_in/flex2/__init__.py new file mode 100644 index 00000000..75aea751 --- /dev/null +++ b/extensions_built_in/flex2/__init__.py @@ -0,0 +1,6 @@ +from .flex2 import Flex2 + +AI_TOOLKIT_MODELS = [ + # put a list of models here + Flex2 +] diff --git a/extensions_built_in/flex2/flex2.py b/extensions_built_in/flex2/flex2.py new file mode 100644 index 00000000..e3234fe1 --- /dev/null +++ b/extensions_built_in/flex2/flex2.py @@ -0,0 +1,483 @@ +import os +from typing import TYPE_CHECKING, List + +import torch +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from diffusers import FluxTransformer2DModel, AutoencoderKL +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.mask import generate_random_mask +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import Flex2Pipeline +from einops import rearrange, repeat +import random +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + + +class Flex2(BaseModel): + arch = "flex2" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['FluxTransformer2DModel'] + + # for training, pass these as kwargs + self.invert_inpaint_mask_chance = model_config.model_kwargs.get('invert_inpaint_mask_chance', 0.0) + self.inpaint_dropout = model_config.model_kwargs.get('inpaint_dropout', 0.0) + self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0) + self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0) + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Flux2 model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + # this is the original path put in the model directory + # it is here because for finetuning we only save the transformer usually + # so we need this for the VAE, te, etc + base_model_path = self.model_config.name_or_path_original + + transformer_path = model_path + transformer_subfolder = 'transformer' + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading transformer") + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=transformer_subfolder, + torch_dtype=dtype, + ) + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + base_model_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + base_model_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + self.print_and_status_update("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + + self.noise_scheduler = Flex2.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: Flex2Pipeline = Flex2Pipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = Flex2.get_train_scheduler() + + pipeline: Flex2Pipeline = Flex2Pipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer) + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: Flex2Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if gen_config.ctrl_img is None: + control_img = None + else: + control_img = Image.open(gen_config.ctrl_img) + if ".inpaint." not in gen_config.ctrl_img: + control_img = control_img.convert("RGB") + else: + # make sure it has an alpha + if control_img.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + control_image=control_img, + control_image_idx=gen_config.ctrl_idx, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + guidance_embedding_scale: float, + bypass_guidance_embedding: bool, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", + b=bs).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet_unwrapped.config.guidance_embeds: + if isinstance(guidance_embedding_scale, list): + guidance = torch.tensor( + guidance_embedding_scale, device=self.device_torch) + else: + guidance = torch.tensor( + [guidance_embedding_scale], device=self.device_torch) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + if bypass_guidance_embedding: + bypass_flux_guidance(self.unet) + + cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to( + self.device_torch, cast_dtype), + timestep=timestep / 1000, + encoder_hidden_states=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype), + pooled_projections=text_embeddings.pooled_embeds.to( + self.device_torch, cast_dtype), + txt_ids=txt_ids, + img_ids=img_ids, + guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=self.vae.config.latent_channels + ) + + if bypass_guidance_embedding: + restore_flux_guidance(self.unet) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, + self.text_encoder, + prompt, + max_length=512, + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + # return from a weight if it has grad + return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: FluxTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + with torch.no_grad(): + # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor + # 4th channel is the mask with 1 being keep area and 0 being area to inpaint. + # todo handle dropout on a batch item level, this frops out the entire batch + do_dropout = random.random() < self.inpaint_dropout if self.inpaint_dropout > 0.0 else False + # do random mask if we dont have one + inpaint_tensor = batch.inpaint_tensor + if self.inpaint_random_chance > 0.0: + do_random = random.random() < self.inpaint_random_chance + if do_random: + # force a random tensor + inpaint_tensor = None + + if inpaint_tensor is None and not do_dropout: + # generate a random one since we dont have one + # this will make random blobs, invert the blobs for now as we normanlly inpaint the alpha + inpaint_tensor = 1 - generate_random_mask( + batch_size=latents.shape[0], + height=latents.shape[2], + width=latents.shape[3], + device=latents.device, + ).to(latents.device, latents.dtype) + if inpaint_tensor is not None and not do_dropout: + + if inpaint_tensor.shape[1] == 4: + # get just the mask + inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype) + elif inpaint_tensor.shape[1] == 3: + # rgb mask. Just get one channel + inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype) + else: + inpainting_tensor_mask = inpaint_tensor + + # # use our batch latents so we cna avoid ancoding again + inpainting_latent = batch.latents + + # resize the mask to match the new encoded size + inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear') + inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype) + + do_mask_invert = False + if self.invert_inpaint_mask_chance > 0.0: + do_mask_invert = random.random() < self.invert_inpaint_mask_chance + if do_mask_invert: + # invert the mask + inpainting_tensor_mask = 1 - inpainting_tensor_mask + + # mask out the inpainting area, it is currently 0 for inpaint area, and 1 for keep area + # we are zeroing our the latents in the inpaint area not on the pixel space. + inpainting_latent = inpainting_latent * inpainting_tensor_mask + + # mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it. + inpainting_tensor_mask = 1 - inpainting_tensor_mask + # leave the mask as 0-1 and concat on channel of latents + inpainting_latent = torch.cat((inpainting_latent, inpainting_tensor_mask), dim=1) + else: + # we have iinpainting but didnt get a control. or we are doing a dropout + # the input needs to be all zeros for the latents and all 1s for the mask + inpainting_latent = torch.zeros_like(latents) + # add ones for the mask since we are technically inpainting everything + inpainting_latent = torch.cat((inpainting_latent, torch.ones_like(inpainting_latent[:, :1, :, :])), dim=1) + + control_tensor = batch.control_tensor + if control_tensor is None: + # concat random normal noise onto the latents + # check dimension, this is before they are rearranged + # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging + ctrl = torch.zeros( + latents.shape[0], # bs + latents.shape[1], + latents.shape[2], + latents.shape[3], + device=latents.device, + dtype=latents.dtype + ) + # inpainting always comes first + ctrl = torch.cat((inpainting_latent, ctrl), dim=1) + latents = torch.cat((latents, ctrl), dim=1) + return latents.detach() + # if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w] + # if we have 1, it comes in like [bs, ch, h, w] + # stack out control tensors to be [bs, ch * num_control_images, h, w] + + control_tensor_list = [] + if len(control_tensor.shape) == 4: + control_tensor_list.append(control_tensor) + else: + num_control_images = control_tensor.shape[1] + # reshape + control_tensor = control_tensor.view( + control_tensor.shape[0], + control_tensor.shape[1] * control_tensor.shape[2], + control_tensor.shape[3], + control_tensor.shape[4] + ) + control_tensor_list = control_tensor.chunk(num_control_images, dim=1) + + do_dropout = random.random() < self.control_dropout if self.control_dropout > 0.0 else False + if do_dropout: + # dropout with zeros + control_latent = torch.zeros_like(batch.latents) + else: + # we only have one control so we randomly pick from this list + control_tensor = random.choice(control_tensor_list) + # it is 0-1 need to convert to -1 to 1 + control_tensor = control_tensor * 2 - 1 + + control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]: + control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear') + + # encode it + control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype) + + # inpainting always comes first + control_latent = torch.cat((inpainting_latent, control_latent), dim=1) + # concat it onto the latents + latents = torch.cat((latents, control_latent), dim=1) + return latents.detach() \ No newline at end of file diff --git a/extensions_built_in/flex2/pipeline.py b/extensions_built_in/flex2/pipeline.py new file mode 100644 index 00000000..8661ff13 --- /dev/null +++ b/extensions_built_in/flex2/pipeline.py @@ -0,0 +1,348 @@ +from diffusers import FluxControlPipeline, FluxTransformer2DModel +from typing import Any, Callable, Dict, List, Optional, Union +import torch + +from diffusers.image_processor import PipelineImageInput +import numpy as np +from PIL import Image +import torch.nn.functional as F +from torchvision import transforms +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, XLA_AVAILABLE + + +class Flex2Pipeline(FluxControlPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + ): + super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + control_image: Optional[PipelineImageInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + control_image_idx: int = 0, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + # num_channels_latents = self.transformer.config.in_channels // 8 + num_channels_latents = 128 // 8 + + # pull mask off control image if there is one it is a pil image + mask = None + if control_image is not None and control_image.mode == "RGBA": + control_img_array = np.array(control_image) + mask = control_img_array[:, :, 3:4] + # scale it to 0 - 1 + mask = mask / 255.0 + # control image ideally would be a full image here + control_img_array = control_img_array[:, :, :3] + control_image = Image.fromarray(control_img_array.astype(np.uint8)) + + if control_image is not None: + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + num_control_channels = num_channels_latents + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if mask is not None: + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + mask = transform(mask).to(device, dtype=control_image.dtype).unsqueeze(0) + # resize mask to match control image + mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False) + mask = mask.to(device) + # apply the mask to the control image so the inpaint latent area is 0 + # mask is currently 0 for inpaint area and 1 for image area + control_image = control_image * mask + # invert mask so it is 1 for inpaint area and 0 for image area + mask = 1 - mask + control_image = torch.cat([control_image, mask], dim=1) + num_control_channels += 1 + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_control_channels, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # make a blank control latent + control_image_list = [ + # impainting + torch.cat([torch.zeros_like(latents), torch.ones_like(latents[:, :, :4])], dim=2), + # control + torch.zeros_like(latents), + ] + if control_image is not None: + + control_image_list[control_image_idx] = control_image + + latent_model_input = torch.cat([latents] + control_image_list, dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + \ No newline at end of file diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 769b702a..89fb0e05 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1675,9 +1675,12 @@ class SDTrainer(BaseSDTrainProcess): else: if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() - if self.adapter and isinstance(self.adapter, CustomAdapter): - with self.timer('condition_noisy_latents'): + with self.timer('condition_noisy_latents'): + # do it for the model + noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) + if self.adapter and isinstance(self.adapter, CustomAdapter): noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + with self.timer('predict_unet'): noise_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ae591bb5..ed0adead 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -405,7 +405,7 @@ class TrainConfig: self.correct_pred_norm = kwargs.get('correct_pred_norm', False) self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) - self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) @@ -470,9 +470,6 @@ class ModelConfig: self.is_auraflow: bool = kwargs.get('is_auraflow', False) self.is_v3: bool = kwargs.get('is_v3', False) self.is_flux: bool = kwargs.get('is_flux', False) - self.is_flex2: bool = kwargs.get('is_flex2', False) - if self.is_flex2: - self.is_flux = True self.is_lumina2: bool = kwargs.get('is_lumina2', False) if self.is_pixart_sigma: self.is_pixart = True @@ -540,6 +537,9 @@ class ModelConfig: self.arch: ModelArch = kwargs.get("arch", None) + # kwargs to pass to the model + self.model_kwargs = kwargs.get("model_kwargs", {}) + # handle migrating to new model arch if self.arch is not None: # reverse the arch to the old style @@ -557,8 +557,6 @@ class ModelConfig: self.is_auraflow = True elif self.arch == 'flux': self.is_flux = True - elif self.arch == 'flex2': - self.is_flex2 = True elif self.arch == 'lumina2': self.is_lumina2 = True elif self.arch == 'vega': @@ -582,8 +580,6 @@ class ModelConfig: self.arch = 'auraflow' elif kwargs.get('is_flux', False): self.arch = 'flux' - elif kwargs.get('is_flex2', False): - self.arch = 'flex2' elif kwargs.get('is_lumina2', False): self.arch = 'lumina2' elif kwargs.get('is_vega', False): @@ -878,6 +874,7 @@ class GenerateImageConfig: self.extra_values = extra_values if extra_values is not None else [] self.num_frames = num_frames self.fps = fps + self.ctrl_img = None self.ctrl_idx = ctrl_idx @@ -1072,6 +1069,8 @@ class GenerateImageConfig: self.num_frames = int(content) elif flag == 'fps': self.fps = int(content) + elif flag == 'ctrl_img': + self.ctrl_img = content elif flag == 'ctrl_idx': self.ctrl_idx = int(content) diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index fcc6cda9..f88bcdaa 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -634,11 +634,8 @@ class CustomAdapter(torch.nn.Module): latents = torch.cat((latents, control_latent), dim=1) return latents.detach() - control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype) if control_tensor is None: - # concat random normal noise onto the latents - # check dimension, this is before they are rearranged - # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging + # concat zeros onto the latents ctrl = torch.zeros( latents.shape[0], # bs latents.shape[1] * self.num_control_images, # ch @@ -656,6 +653,8 @@ class CustomAdapter(torch.nn.Module): # if we have 1, it comes in like [bs, ch, h, w] # stack out control tensors to be [bs, ch * num_control_images, h, w] + control_tensor = batch.control_tensor.to(latents.device, dtype=latents.dtype) + control_tensor_list = [] if len(control_tensor.shape) == 4: control_tensor_list.append(control_tensor) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 8adcffa8..43ecb1be 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -42,6 +42,7 @@ from toolkit.print import print_acc if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO # tell it to shut up diffusers.logging.set_verbosity(diffusers.logging.ERROR) @@ -97,6 +98,8 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも class BaseModel: + # override these in child classes + arch = None def __init__( self, @@ -174,6 +177,14 @@ class BaseModel: @unet.setter def unet(self, value): self.model = value + + @property + def transformer(self): + return self.model + + @transformer.setter + def transformer(self, value): + self.model = value @property def unet_unwrapped(self): @@ -215,10 +226,6 @@ class BaseModel: def is_flux(self): return self.arch == 'flux' - @property - def is_flex2(self): - return self.arch == 'flex2' - @property def is_lumina2(self): return self.arch == 'lumina2' @@ -385,8 +392,13 @@ class BaseModel: extra = {} validation_image = None if self.adapter is not None and gen_config.adapter_image_path is not None: - validation_image = Image.open( - gen_config.adapter_image_path).convert("RGB") + validation_image = Image.open(gen_config.adapter_image_path) + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") if isinstance(self.adapter, T2IAdapter): # not sure why this is double?? validation_image = validation_image.resize( @@ -398,6 +410,10 @@ class BaseModel: (gen_config.width, gen_config.height)) extra['image'] = validation_image extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['control_image'] = validation_image + extra['control_image_idx'] = gen_config.ctrl_idx if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): transform = transforms.Compose([ transforms.ToTensor(), @@ -786,6 +802,8 @@ class BaseModel: latent_model_input=latent_model_input, timestep=timestep, text_embeddings=text_embeddings, + guidance_embedding_scale=guidance_embedding_scale, + bypass_guidance_embedding=bypass_guidance_embedding, **kwargs ) @@ -1431,3 +1449,7 @@ class BaseModel: def convert_lora_weights_before_load(self, state_dict): # can be overridden in child classes to convert weights before loading return state_dict + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + # can be overridden in child classes to condition latents before noise prediction + return latents diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py index bdb692e0..13fa4267 100644 --- a/toolkit/models/cogview4.py +++ b/toolkit/models/cogview4.py @@ -60,6 +60,7 @@ scheduler_config = { class CogView4(BaseModel): + arch = 'cogview4' def __init__( self, device, diff --git a/toolkit/models/flex2.py b/toolkit/models/flex2.py deleted file mode 100644 index eabff80c..00000000 --- a/toolkit/models/flex2.py +++ /dev/null @@ -1,993 +0,0 @@ -from typing import List, Optional, Union -from diffusers import FluxPipeline -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from transformers import AutoModel, AutoTokenizer -from transformers import ( - CLIPImageProcessor, - CLIPTextModel, - CLIPTokenizer, - CLIPVisionModelWithProjection -) - - - -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, FluxTransformer2DModel -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import Flex2Pipeline - - >>> pipe = Flex2Pipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> # Depending on the variant being used, the pipeline call will slightly vary. - >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") - ``` -""" - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - - -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class Flex2Pipeline( - DiffusionPipeline, - FluxLoraLoaderMixin, - FromSingleFileMixin, - TextualInversionLoaderMixin, - FluxIPAdapterMixin, -): - r""" - The Flux pipeline for text-to-image generation. - - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ - - Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" - _optional_components = ["image_encoder", "feature_extractor"] - _callback_tensor_inputs = ["latents", "prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - text_encoder_2: AutoModel, - tokenizer_2: AutoTokenizer, - transformer: FluxTransformer2DModel, - image_encoder: CLIPVisionModelWithProjection = None, - feature_extractor: CLIPImageProcessor = None, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - image_encoder=image_encoder, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 - # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible - # by the patch size. So the vae scale factor is multiplied by the patch size to account for this - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = 128 - self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " - - # determine length of system prompt - self.system_prompt_length = self.tokenizer_2( - [self.system_prompt], - padding="longest", - return_tensors="pt", - ).input_ids[0].shape[0] - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def _get_llm_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length + self.system_prompt_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids.to(device) - prompt_attention_mask = text_inputs.attention_mask.to(device) - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length + self.system_prompt_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2( - text_input_ids, - attention_mask=prompt_attention_mask, - output_hidden_states=True - ) - prompt_embeds = prompt_embeds.hidden_states[-1] - - # remove the system prompt from the input and attention mask - prompt_embeds = prompt_embeds[:, self.system_prompt_length:] - prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_llm_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids - - def encode_image(self, image, device, num_images_per_prompt): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - return image_embeds - - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt - ): - image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): - single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - - image_embeds.append(single_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - def check_inputs( - self, - prompt, - prompt_2, - height, - width, - negative_prompt=None, - negative_prompt_2=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: - logger.warning( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - - shape = (batch_size, num_channels_latents, height, width) - - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def current_timestep(self): - return self._current_timestep - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - negative_prompt: Union[str, List[str]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - true_cfg_scale: float = 1.0, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 28, - sigmas: Optional[List[float]] = None, - guidance_scale: float = 3.5, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - negative_ip_adapter_image: Optional[PipelineImageInput] = None, - negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is - not greater than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - true_cfg_scale (`float`, *optional*, defaults to 1.0): - When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of - IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. - negative_ip_adapter_image: - (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of - IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - height, - width, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs - self._current_timestep = None - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - has_neg_prompt = negative_prompt is not None or ( - negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None - ) - do_true_cfg = true_cfg_scale > 1 and has_neg_prompt - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - if do_true_cfg: - ( - negative_prompt_embeds, - negative_pooled_prompt_embeds, - _, - ) = self.encode_prompt( - prompt=negative_prompt, - prompt_2=negative_prompt_2, - prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = latents.shape[1] - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.16), - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - sigmas=sigmas, - mu=mu, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( - negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None - ): - negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) - elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( - negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None - ): - ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) - - if self.joint_attention_kwargs is None: - self._joint_attention_kwargs = {} - - image_embeds = None - negative_image_embeds = None - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, - ip_adapter_image_embeds, - device, - batch_size * num_images_per_prompt, - ) - if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: - negative_image_embeds = self.prepare_ip_adapter_image_embeds( - negative_ip_adapter_image, - negative_ip_adapter_image_embeds, - device, - batch_size * num_images_per_prompt, - ) - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - if image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - self._current_timestep = None - - if output_type == "latent": - image = latents - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return FluxPipelineOutput(images=image) diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index a8bb7446..e1636b9d 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -300,6 +300,7 @@ class AggressiveWanUnloadPipeline(WanPipeline): class Wan21(BaseModel): + arch = 'wan21' def __init__( self, device, diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 5e3f58f4..53d41637 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -53,7 +53,6 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \ FluxControlPipeline from toolkit.models.lumina2 import Lumina2Transformer2DModel -from toolkit.models.flex2 import Flex2Pipeline import diffusers from diffusers import \ AutoencoderKL, \ @@ -189,7 +188,6 @@ class StableDiffusion: # self.is_pixart = model_config.is_pixart # self.is_auraflow = model_config.is_auraflow # self.is_flux = model_config.is_flux - # self.is_flex2 = model_config.is_flex2 # self.is_lumina2 = model_config.is_lumina2 self.use_text_encoder_1 = model_config.use_text_encoder_1 @@ -244,10 +242,6 @@ class StableDiffusion: def is_flux(self): return self.arch == 'flux' - @property - def is_flex2(self): - return self.arch == 'flex2' - @property def is_lumina2(self): return self.arch == 'lumina2' @@ -755,24 +749,16 @@ class StableDiffusion: vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() - if self.is_flex2: - tokenizer_2 = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer_2") - text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) - - else: - self.print_and_status_update("Loading T5") - tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) - text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", - torch_dtype=dtype) + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", + torch_dtype=dtype) text_encoder_2.to(self.device_torch, dtype=dtype) flush() if self.model_config.quantize_te: - if self.is_flex2: - self.print_and_status_update("Quantizing LLM") - else: - self.print_and_status_update("Quantizing T5") + self.print_and_status_update("Quantizing T5") quantize(text_encoder_2, weights=get_qtype(self.model_config.qtype)) freeze(text_encoder_2) flush() @@ -784,8 +770,6 @@ class StableDiffusion: self.print_and_status_update("Making pipe") Pipe = FluxPipeline - if self.is_flex2: - Pipe = Flex2Pipeline pipe: Pipe = Pipe( scheduler=scheduler, @@ -1164,8 +1148,6 @@ class StableDiffusion: arch = 'pixart' if self.is_flux: arch = 'flux' - if self.is_flex2: - arch = 'flex2' if self.is_lumina2: arch = 'lumina2' noise_scheduler = get_sampler( @@ -1240,8 +1222,6 @@ class StableDiffusion: else: Pipe = FluxPipeline - if self.is_flex2: - Pipe = Flex2Pipeline if self.adapter is not None and isinstance(self.adapter, CustomAdapter): # see if it is a control lora if self.adapter.control_lora is not None: @@ -2405,18 +2385,6 @@ class StableDiffusion: embeds, attention_mask=attention_mask, # not used ) - elif self.is_flex2: - prompt_embeds, pooled_prompt_embeds, text_ids = self.pipeline.encode_prompt( - prompt, - prompt, - device=self.device_torch, - max_sequence_length=512, - ) - pe = PromptEmbeds( - prompt_embeds - ) - pe.pooled_embeds = pooled_prompt_embeds - return pe elif self.is_flux: prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( self.tokenizer, # list @@ -3091,3 +3059,7 @@ class StableDiffusion: def convert_lora_weights_before_load(self, state_dict): # can be overridden in child classes to convert weights before loading return state_dict + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + # can be overridden in child classes to condition latents before noise prediction + return latents diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py index 4d1668f8..280fc632 100644 --- a/toolkit/util/get_model.py +++ b/toolkit/util/get_model.py @@ -1,12 +1,49 @@ +import os +from typing import List +from toolkit.models.base_model import BaseModel from toolkit.stable_diffusion_model import StableDiffusion from toolkit.config_modules import ModelConfig +from toolkit.paths import TOOLKIT_ROOT +import importlib +import pkgutil + +from toolkit.models.wan21 import Wan21 +from toolkit.models.cogview4 import CogView4 + +BUILT_IN_MODELS = [ + Wan21, + CogView4, +] + + +def get_all_models() -> List[BaseModel]: + extension_folders = ['extensions', 'extensions_built_in'] + + # This will hold the classes from all extension modules + all_model_classes: List[BaseModel] = BUILT_IN_MODELS + + # Iterate over all directories (i.e., packages) in the "extensions" directory + for sub_dir in extension_folders: + extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir) + for (_, name, _) in pkgutil.iter_modules([extensions_dir]): + try: + # Import the module + module = importlib.import_module(f"{sub_dir}.{name}") + # Get the value of the AI_TOOLKIT_MODELS variable + models = getattr(module, "AI_TOOLKIT_MODELS", None) + # Check if the value is a list + if isinstance(models, list): + # Iterate over the list and add the classes to the main list + all_model_classes.extend(models) + except ImportError as e: + print(f"Failed to import the {name} module. Error: {str(e)}") + return all_model_classes + def get_model_class(config: ModelConfig): - if config.arch == "wan21": - from toolkit.models.wan21 import Wan21 - return Wan21 - elif config.arch == "cogview4": - from toolkit.models.cogview4 import CogView4 - return CogView4 - else: - return StableDiffusion \ No newline at end of file + all_models = get_all_models() + for ModelClass in all_models: + if ModelClass.arch == config.arch: + return ModelClass + # default to the legacy model + return StableDiffusion