Handle multi control inputs for control lora training

This commit is contained in:
Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions

View File

@@ -1063,7 +1063,6 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, CustomAdapter): if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt # condition the prompt
# todo handle more than one adapter image # todo handle more than one adapter image
self.adapter.num_control_images = 1
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list() network_weight_list = batch.get_network_weight_list()

View File

@@ -241,6 +241,9 @@ class AdapterConfig:
self.lora_config: NetworkConfig = NetworkConfig(**lora_config) self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
else: else:
self.lora_config = None self.lora_config = None
self.num_control_images: int = kwargs.get('num_control_images', 1)
# probability (0.0-1.0) that each control image is dropped out and replaced with an all-zero latent; 1.0 is 100%
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
class EmbeddingConfig: class EmbeddingConfig:
@@ -710,7 +713,7 @@ class DatasetConfig:
self.flip_x: bool = kwargs.get('flip_x', False) self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False) self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', []) self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
# instead of cropping to match image, it will serve the full size control image (clip images ie for ip adapters) # instead of cropping to match image, it will serve the full size control image (clip images ie for ip adapters)
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False) self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
@@ -833,6 +836,7 @@ class GenerateImageConfig:
logger: Optional[EmptyLogger] = None, logger: Optional[EmptyLogger] = None,
num_frames: int = 1, num_frames: int = 1,
fps: int = 15, fps: int = 15,
ctrl_idx: int = 0
): ):
self.width: int = width self.width: int = width
self.height: int = height self.height: int = height
@@ -863,6 +867,7 @@ class GenerateImageConfig:
self.extra_values = extra_values if extra_values is not None else [] self.extra_values = extra_values if extra_values is not None else []
self.num_frames = num_frames self.num_frames = num_frames
self.fps = fps self.fps = fps
self.ctrl_idx = ctrl_idx
# prompt string will override any settings above # prompt string will override any settings above
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
self.num_frames = int(content) self.num_frames = int(content)
elif flag == 'fps': elif flag == 'fps':
self.fps = int(content) self.fps = int(content)
elif flag == 'ctrl_idx':
self.ctrl_idx = int(content)
def post_process_embeddings( def post_process_embeddings(
self, self,

View File

@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
self.position_ids: Optional[List[int]] = None self.position_ids: Optional[List[int]] = None
self.num_control_images = 1 self.num_control_images = self.config.num_control_images
self.token_mask: Optional[torch.Tensor] = None self.token_mask: Optional[torch.Tensor] = None
# setup clip # setup clip
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
# concat random normal noise onto the latents # concat random normal noise onto the latents
# check dimension, this is before they are rearranged # check dimension, this is before they are rearranged
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging # it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
latents = torch.cat((latents, torch.randn_like(latents)), dim=1) ctrl = torch.randn(
latents.shape[0], # bs
latents.shape[1] * self.num_control_images, # ch
latents.shape[2],
latents.shape[3],
device=latents.device,
dtype=latents.dtype
)
latents = torch.cat((latents, ctrl), dim=1)
return latents.detach() return latents.detach()
# it is 0-1 need to convert to -1 to 1 # if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
control_tensor = control_tensor * 2 - 1 # if we have 1, it comes in like [bs, ch, h, w]
# stack out control tensors to be [bs, ch * num_control_images, h, w]
control_tensor_list = []
if len(control_tensor.shape) == 4:
control_tensor_list.append(control_tensor)
else:
# reshape
control_tensor = control_tensor.view(
control_tensor.shape[0],
control_tensor.shape[1] * control_tensor.shape[2],
control_tensor.shape[3],
control_tensor.shape[4]
)
control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
control_latent_list = []
for control_tensor in control_tensor_list:
do_dropout = random.random() < self.config.control_image_dropout
if do_dropout:
# dropout: substitute an all-zero latent for this control image
control_latent_list.append(torch.zeros_like(batch.latents))
else:
# it is 0-1 need to convert to -1 to 1
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype) control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]: if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic') control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
# encode it # encode it
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype) control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
control_latent_list.append(control_latent)
# stack them on the channel dimension
control_latent = torch.cat(control_latent_list, dim=1)
# concat it onto the latents # concat it onto the latents
latents = torch.cat((latents, control_latent), dim=1) latents = torch.cat((latents, control_latent), dim=1)
return latents.detach() return latents.detach()

View File

@@ -743,75 +743,100 @@ class ControlFileItemDTOMixin:
if hasattr(super(), '__init__'): if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.has_control_image = False self.has_control_image = False
self.control_path: Union[str, None] = None self.control_path: Union[str, List[str], None] = None
self.control_tensor: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False self.full_size_control_images = False
if dataset_config.control_path is not None: if dataset_config.control_path is not None:
# find the control image path # find the control image path
control_path = dataset_config.control_path control_path_list = dataset_config.control_path
if not isinstance(control_path_list, list):
control_path_list = [control_path_list]
self.full_size_control_images = dataset_config.full_size_control_images self.full_size_control_images = dataset_config.full_size_control_images
# we are using control images # we are using control images
img_path = kwargs.get('path', None) img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
for ext in img_ext_list:
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)): found_control_images = []
self.control_path = os.path.join(control_path, file_name_no_ext + ext) for control_path in control_path_list:
self.has_control_image = True for ext in img_ext_list:
break if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
found_control_images.append(os.path.join(control_path, file_name_no_ext + ext))
self.has_control_image = True
break
self.control_path = found_control_images
if len(self.control_path) == 0:
self.control_path = None
elif len(self.control_path) == 1:
# only do one
self.control_path = self.control_path[0]
def load_control_image(self: 'FileItemDTO'): def load_control_image(self: 'FileItemDTO'):
try: control_tensors = []
img = Image.open(self.control_path).convert('RGB') control_path_list = self.control_path
img = exif_transpose(img) if not isinstance(self.control_path, list):
except Exception as e: control_path_list = [self.control_path]
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {self.control_path}") for control_path in control_path_list:
try:
img = Image.open(control_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {control_path}")
if self.full_size_control_images: if self.full_size_control_images:
# we just scale them to 512x512: # we just scale them to 512x512:
w, h = img.size w, h = img.size
img = img.resize((512, 512), Image.BICUBIC) img = img.resize((512, 512), Image.BICUBIC)
else:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else: else:
raise Exception("Control images not supported for non-bucket datasets") w, h = img.size
transform = transforms.Compose([ if w > h and self.scale_to_width < self.scale_to_height:
transforms.ToTensor(), # throw error, they should match
]) raise ValueError(
if self.aug_replay_spatial_transforms: f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
self.control_tensor = self.augment_spatial_control(img, transform=transform) elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
tensor = self.augment_spatial_control(img, transform=transform)
else:
tensor = transform(img)
control_tensors.append(tensor)
if len(control_tensors) == 0:
self.control_tensor = None
elif len(control_tensors) == 1:
self.control_tensor = control_tensors[0]
else: else:
self.control_tensor = transform(img) self.control_tensor = torch.stack(control_tensors, dim=0)
def cleanup_control(self: 'FileItemDTO'): def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None self.control_tensor = None

View File

@@ -46,14 +46,14 @@ class ImgEmbedder(torch.nn.Module):
cls, cls,
model: FluxTransformer2DModel, model: FluxTransformer2DModel,
adapter: 'ControlLoraAdapter', adapter: 'ControlLoraAdapter',
num_channel_multiplier=2 num_control_images=1
): ):
if model.__class__.__name__ == 'FluxTransformer2DModel': if model.__class__.__name__ == 'FluxTransformer2DModel':
x_embedder: torch.nn.Linear = model.x_embedder x_embedder: torch.nn.Linear = model.x_embedder
img_embedder = cls( img_embedder = cls(
adapter, adapter,
orig_layer=x_embedder, orig_layer=x_embedder,
in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels in_channels=x_embedder.in_features * num_control_images,
out_channels=x_embedder.out_features, out_channels=x_embedder.out_features,
) )
@@ -62,7 +62,7 @@ class ImgEmbedder(torch.nn.Module):
x_embedder.forward = img_embedder.forward x_embedder.forward = img_embedder.forward
# update the config of the transformer # update the config of the transformer
model.config.in_channels = model.config.in_channels * num_channel_multiplier model.config.in_channels = model.config.in_channels * (num_control_images + 1)
model.config["in_channels"] = model.config.in_channels model.config["in_channels"] = model.config.in_channels
return img_embedder return img_embedder
@@ -178,7 +178,11 @@ class ControlLoraAdapter(torch.nn.Module):
if self.train_config.gradient_checkpointing: if self.train_config.gradient_checkpointing:
self.control_lora.enable_gradient_checkpointing() self.control_lora.enable_gradient_checkpointing()
self.x_embedder = ImgEmbedder.from_model(sd.unet, self) self.x_embedder = ImgEmbedder.from_model(
sd.unet,
self,
num_control_images=config.num_control_images
)
self.x_embedder.to(self.device_torch) self.x_embedder.to(self.device_torch)
def get_params(self): def get_params(self):
@@ -230,6 +234,16 @@ class ControlLoraAdapter(torch.nn.Module):
# todo process state dict before loading # todo process state dict before loading
if self.control_lora is not None: if self.control_lora is not None:
self.control_lora.load_weights(lora_sd) self.control_lora.load_weights(lora_sd)
# automatically upgrade the x embedder if more dims are added
if self.x_embedder.weight.shape[1] > img_embedder_sd['weight'].shape[1]:
print("Upgrading x_embedder from {} to {}".format(
img_embedder_sd['weight'].shape[1],
self.x_embedder.weight.shape[1]
))
while img_embedder_sd['weight'].shape[1] < self.x_embedder.weight.shape[1]:
img_embedder_sd['weight'] = torch.cat([img_embedder_sd['weight'] ] * 2, dim=1)
if img_embedder_sd['weight'].shape[1] > self.x_embedder.weight.shape[1]:
img_embedder_sd['weight'] = img_embedder_sd['weight'][:, :self.x_embedder.weight.shape[1]]
self.x_embedder.load_state_dict(img_embedder_sd, strict=False) self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
def get_state_dict(self): def get_state_dict(self):

View File

@@ -4,7 +4,7 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import numpy as np import numpy as np
import torch import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline, FluxControlPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -15,6 +15,7 @@ from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
from diffusers.image_processor import PipelineImageInput
if is_torch_xla_available(): if is_torch_xla_available():
@@ -1423,4 +1424,293 @@ class FluxWithCFGPipeline(FluxPipeline):
if not return_dict: if not return_dict:
return (image,) return (image,)
return FluxPipelineOutput(images=image) return FluxPipelineOutput(images=image)
class FluxAdvancedControlPipeline(FluxControlPipeline):
    """Flux control pipeline for transformers that accept several control inputs.

    The transformer's ``in_channels`` is read as ``64 * (1 + number_of_controls)``:
    the first 64 packed channels carry the noisy latents and each further group
    of 64 carries one control image. A single ``control_image`` is supplied per
    call and is placed in the slot selected by ``control_image_idx``; every other
    control slot is fed zeros.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        control_image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        control_image_idx: int = 0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The control input condition used to guide generation. It is preprocessed with
                `self.prepare_image` to `width` x `height`; a 4-dimensional result is encoded with the VAE and
                packed, any other result is used as is.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, a linear schedule from 1.0 down to
                `1 / num_inference_steps` is used.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Higher guidance scale encourages to generate images
                that are closely linked to the text `prompt`, usually at the expense of lower image quality.
                It is only passed to the transformer when `transformer.config.guidance_embeds` is set.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. `"latent"`
                returns the packed latents without decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is
                called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
                timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
                specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
            control_image_idx (`int`, *optional*, defaults to 0):
                Index of the control slot that receives `control_image`. Valid values are
                `-num_controls .. num_controls - 1`, where
                `num_controls = transformer.config.in_channels // 64 - 1`. All other slots are filled with zeros.

        Raises:
            IndexError: If `control_image_idx` does not address an existing control slot (including the case
                where the transformer has no control slots at all).

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with
            the generated images.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        # flux has 64 input channels; every extra group of 64 is one control slot.
        total_controls = (self.transformer.config.in_channels // 64) - 1
        # Validate the slot before any expensive work (prompt/VAE encoding) instead of
        # failing inside the denoising loop with an unexplained IndexError.
        if not -total_controls <= control_image_idx < total_controls:
            raise IndexError(
                f"control_image_idx {control_image_idx} is out of range for a transformer "
                f"with {total_controls} control input(s)"
            )
        # normalise negative indices so the slot can be compared positionally below
        control_slot = control_image_idx % total_controls

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare text embeddings
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        # The channel count is fixed rather than derived from
        # self.transformer.config.in_channels // 8, because in_channels varies
        # with the number of control slots.
        num_channels_latents = 128 // 8

        control_image = self.prepare_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=self.vae.dtype,
        )

        # a 4-dim tensor is still pixel data: encode it and pack it like the latents
        if control_image.ndim == 4:
            control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

            height_control_image, width_control_image = control_image.shape[2:]
            control_image = self._pack_latents(
                control_image,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height_control_image,
                width_control_image,
            )

        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # The chosen slot gets the control image, every other slot gets zeros.
                # Built per step so the zero fillers always follow the current latents.
                control_image_list = [
                    control_image if slot == control_slot else torch.zeros_like(latents)
                    for slot in range(total_controls)
                ]

                latent_model_input = torch.cat([latents] + control_image_list, dim=2)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # advance the progress bar on the last step and after each full scheduler-order step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)

View File

@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from einops import rearrange, repeat from einops import rearrange, repeat
import torch import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
FluxAdvancedControlPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
@@ -1244,7 +1245,7 @@ class StableDiffusion:
if self.adapter is not None and isinstance(self.adapter, CustomAdapter): if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
# see if it is a control lora # see if it is a control lora
if self.adapter.control_lora is not None: if self.adapter.control_lora is not None:
Pipe = FluxControlPipeline Pipe = FluxAdvancedControlPipeline
pipeline = Pipe( pipeline = Pipe(
vae=self.vae, vae=self.vae,
@@ -1367,6 +1368,7 @@ class StableDiffusion:
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
validation_image = validation_image.resize((gen_config.width, gen_config.height)) validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['control_image'] = validation_image extra['control_image'] = validation_image
extra['control_image_idx'] = gen_config.ctrl_idx
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([ transform = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),