mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Handle multi control inputs for control lora training
This commit is contained in:
@@ -1063,7 +1063,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
# condition the prompt
|
||||
# todo handle more than one adapter image
|
||||
self.adapter.num_control_images = 1
|
||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
@@ -241,6 +241,9 @@ class AdapterConfig:
|
||||
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
|
||||
else:
|
||||
self.lora_config = None
|
||||
self.num_control_images: int = kwargs.get('num_control_images', 1)
|
||||
# decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
|
||||
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -710,7 +713,7 @@ class DatasetConfig:
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
|
||||
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
|
||||
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
@@ -833,6 +836,7 @@ class GenerateImageConfig:
|
||||
logger: Optional[EmptyLogger] = None,
|
||||
num_frames: int = 1,
|
||||
fps: int = 15,
|
||||
ctrl_idx: int = 0
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -863,6 +867,7 @@ class GenerateImageConfig:
|
||||
self.extra_values = extra_values if extra_values is not None else []
|
||||
self.num_frames = num_frames
|
||||
self.fps = fps
|
||||
self.ctrl_idx = ctrl_idx
|
||||
|
||||
|
||||
# prompt string will override any settings above
|
||||
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
|
||||
self.num_frames = int(content)
|
||||
elif flag == 'fps':
|
||||
self.fps = int(content)
|
||||
elif flag == 'ctrl_idx':
|
||||
self.ctrl_idx = int(content)
|
||||
|
||||
def post_process_embeddings(
|
||||
self,
|
||||
|
||||
@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
self.position_ids: Optional[List[int]] = None
|
||||
|
||||
self.num_control_images = 1
|
||||
self.num_control_images = self.config.num_control_images
|
||||
self.token_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# setup clip
|
||||
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
|
||||
# concat random normal noise onto the latents
|
||||
# check dimension, this is before they are rearranged
|
||||
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
||||
latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
|
||||
ctrl = torch.randn(
|
||||
latents.shape[0], # bs
|
||||
latents.shape[1] * self.num_control_images, # ch
|
||||
latents.shape[2],
|
||||
latents.shape[3],
|
||||
device=latents.device,
|
||||
dtype=latents.dtype
|
||||
)
|
||||
latents = torch.cat((latents, ctrl), dim=1)
|
||||
return latents.detach()
|
||||
# it is 0-1 need to convert to -1 to 1
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
|
||||
# if we have 1, it comes in like [bs, ch, h, w]
|
||||
# stack out control tensors to be [bs, ch * num_control_images, h, w]
|
||||
|
||||
control_tensor_list = []
|
||||
if len(control_tensor.shape) == 4:
|
||||
control_tensor_list.append(control_tensor)
|
||||
else:
|
||||
# reshape
|
||||
control_tensor = control_tensor.view(
|
||||
control_tensor.shape[0],
|
||||
control_tensor.shape[1] * control_tensor.shape[2],
|
||||
control_tensor.shape[3],
|
||||
control_tensor.shape[4]
|
||||
)
|
||||
control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
|
||||
control_latent_list = []
|
||||
for control_tensor in control_tensor_list:
|
||||
do_dropout = random.random() < self.config.control_image_dropout
|
||||
if do_dropout:
|
||||
# dropout with noise
|
||||
control_latent_list.append(torch.zeros_like(batch.latents))
|
||||
else:
|
||||
# it is 0-1 need to convert to -1 to 1
|
||||
control_tensor = control_tensor * 2 - 1
|
||||
|
||||
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
||||
|
||||
# encode it
|
||||
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
||||
|
||||
# encode it
|
||||
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||
control_latent_list.append(control_latent)
|
||||
# stack them on the channel dimension
|
||||
control_latent = torch.cat(control_latent_list, dim=1)
|
||||
# concat it onto the latents
|
||||
latents = torch.cat((latents, control_latent), dim=1)
|
||||
return latents.detach()
|
||||
|
||||
@@ -743,75 +743,100 @@ class ControlFileItemDTOMixin:
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.has_control_image = False
|
||||
self.control_path: Union[str, None] = None
|
||||
self.control_path: Union[str, List[str], None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.full_size_control_images = False
|
||||
if dataset_config.control_path is not None:
|
||||
# find the control image path
|
||||
control_path = dataset_config.control_path
|
||||
control_path_list = dataset_config.control_path
|
||||
if not isinstance(control_path_list, list):
|
||||
control_path_list = [control_path_list]
|
||||
self.full_size_control_images = dataset_config.full_size_control_images
|
||||
# we are using control images
|
||||
img_path = kwargs.get('path', None)
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||
for ext in img_ext_list:
|
||||
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
|
||||
self.control_path = os.path.join(control_path, file_name_no_ext + ext)
|
||||
self.has_control_image = True
|
||||
break
|
||||
|
||||
found_control_images = []
|
||||
for control_path in control_path_list:
|
||||
for ext in img_ext_list:
|
||||
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
|
||||
found_control_images.append(os.path.join(control_path, file_name_no_ext + ext))
|
||||
self.has_control_image = True
|
||||
break
|
||||
self.control_path = found_control_images
|
||||
if len(self.control_path) == 0:
|
||||
self.control_path = None
|
||||
elif len(self.control_path) == 1:
|
||||
# only do one
|
||||
self.control_path = self.control_path[0]
|
||||
|
||||
def load_control_image(self: 'FileItemDTO'):
|
||||
try:
|
||||
img = Image.open(self.control_path).convert('RGB')
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {self.control_path}")
|
||||
control_tensors = []
|
||||
control_path_list = self.control_path
|
||||
if not isinstance(self.control_path, list):
|
||||
control_path_list = [self.control_path]
|
||||
|
||||
for control_path in control_path_list:
|
||||
try:
|
||||
img = Image.open(control_path).convert('RGB')
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
if self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
if self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
else:
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
if self.flip_y:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
self.control_tensor = self.augment_spatial_control(img, transform=transform)
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
if self.flip_y:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
if self.aug_replay_spatial_transforms:
|
||||
tensor = self.augment_spatial_control(img, transform=transform)
|
||||
else:
|
||||
tensor = transform(img)
|
||||
control_tensors.append(tensor)
|
||||
|
||||
if len(control_tensors) == 0:
|
||||
self.control_tensor = None
|
||||
elif len(control_tensors) == 1:
|
||||
self.control_tensor = control_tensors[0]
|
||||
else:
|
||||
self.control_tensor = transform(img)
|
||||
self.control_tensor = torch.stack(control_tensors, dim=0)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
|
||||
@@ -46,14 +46,14 @@ class ImgEmbedder(torch.nn.Module):
|
||||
cls,
|
||||
model: FluxTransformer2DModel,
|
||||
adapter: 'ControlLoraAdapter',
|
||||
num_channel_multiplier=2
|
||||
num_control_images=1
|
||||
):
|
||||
if model.__class__.__name__ == 'FluxTransformer2DModel':
|
||||
x_embedder: torch.nn.Linear = model.x_embedder
|
||||
img_embedder = cls(
|
||||
adapter,
|
||||
orig_layer=x_embedder,
|
||||
in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels
|
||||
in_channels=x_embedder.in_features * num_control_images,
|
||||
out_channels=x_embedder.out_features,
|
||||
)
|
||||
|
||||
@@ -62,7 +62,7 @@ class ImgEmbedder(torch.nn.Module):
|
||||
x_embedder.forward = img_embedder.forward
|
||||
|
||||
# update the config of the transformer
|
||||
model.config.in_channels = model.config.in_channels * num_channel_multiplier
|
||||
model.config.in_channels = model.config.in_channels * (num_control_images + 1)
|
||||
model.config["in_channels"] = model.config.in_channels
|
||||
|
||||
return img_embedder
|
||||
@@ -178,7 +178,11 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.control_lora.enable_gradient_checkpointing()
|
||||
|
||||
self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
|
||||
self.x_embedder = ImgEmbedder.from_model(
|
||||
sd.unet,
|
||||
self,
|
||||
num_control_images=config.num_control_images
|
||||
)
|
||||
self.x_embedder.to(self.device_torch)
|
||||
|
||||
def get_params(self):
|
||||
@@ -230,6 +234,16 @@ class ControlLoraAdapter(torch.nn.Module):
|
||||
# todo process state dict before loading
|
||||
if self.control_lora is not None:
|
||||
self.control_lora.load_weights(lora_sd)
|
||||
# automatically upgrade the x imbedder if more dims are added
|
||||
if self.x_embedder.weight.shape[1] > img_embedder_sd['weight'].shape[1]:
|
||||
print("Upgrading x_embedder from {} to {}".format(
|
||||
img_embedder_sd['weight'].shape[1],
|
||||
self.x_embedder.weight.shape[1]
|
||||
))
|
||||
while img_embedder_sd['weight'].shape[1] < self.x_embedder.weight.shape[1]:
|
||||
img_embedder_sd['weight'] = torch.cat([img_embedder_sd['weight'] ] * 2, dim=1)
|
||||
if img_embedder_sd['weight'].shape[1] > self.x_embedder.weight.shape[1]:
|
||||
img_embedder_sd['weight'] = img_embedder_sd['weight'][:, :self.x_embedder.weight.shape[1]]
|
||||
self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
|
||||
|
||||
def get_state_dict(self):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline, FluxControlPipeline
|
||||
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@@ -15,6 +15,7 @@ from diffusers.utils import is_torch_xla_available
|
||||
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -1423,4 +1424,293 @@ class FluxWithCFGPipeline(FluxPipeline):
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
|
||||
class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 3.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
control_image_idx: int = 0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
||||
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
||||
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
||||
images must be passed as a list such that each element of the list can be correctly batched for input
|
||||
to a single ControlNet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
text_ids,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
# num_channels_latents = self.transformer.config.in_channels // 8
|
||||
num_channels_latents = 128 // 8
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
if control_image.ndim == 4:
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# flux has 64 input channels.
|
||||
total_controls = (self.transformer.config.in_channels // 64) - 1
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
|
||||
control_image_list[control_image_idx] = control_image
|
||||
|
||||
latent_model_input = torch.cat([latents] + control_image_list, dim=2)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
from einops import rearrange, repeat
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
|
||||
FluxAdvancedControlPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
@@ -1244,7 +1245,7 @@ class StableDiffusion:
|
||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||
# see if it is a control lora
|
||||
if self.adapter.control_lora is not None:
|
||||
Pipe = FluxControlPipeline
|
||||
Pipe = FluxAdvancedControlPipeline
|
||||
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
@@ -1367,6 +1368,7 @@ class StableDiffusion:
|
||||
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['control_image'] = validation_image
|
||||
extra['control_image_idx'] = gen_config.ctrl_idx
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
|
||||
Reference in New Issue
Block a user