Handle multiple control inputs for control LoRA training

Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions
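For reference, a minimal sketch of how the new options fit together in a training config. The key names (control_path, num_control_images, control_image_dropout) come from this diff; the surrounding structure, including the adapter type string, is an assumption.

# sketch only: the nesting and the adapter 'type' value are assumptions
dataset_config = dict(
    folder_path='/path/to/images',
    control_path=[                      # now accepts a list of control folders
        '/path/to/control_depth',
        '/path/to/control_pose',
    ],
)
adapter_config = dict(
    type='control_lora',                # assumed type name for the control LoRA adapter
    num_control_images=2,               # should match the number of control_path entries
    control_image_dropout=0.1,          # ~10% of steps drop a control input
)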

View File

@@ -1063,7 +1063,6 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
self.adapter.num_control_images = 1
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()

View File

@@ -241,6 +241,9 @@ class AdapterConfig:
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
else:
self.lora_config = None
self.num_control_images: int = kwargs.get('num_control_images', 1)
# fraction of the time a control image is dropped out (replaced with a blank control); 1.0 is 100%
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
class EmbeddingConfig:
@@ -710,7 +713,7 @@ class DatasetConfig:
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
# instead of cropping to match the image, it will serve the full size control image (e.g. clip images for ip adapters)
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
@@ -833,6 +836,7 @@ class GenerateImageConfig:
logger: Optional[EmptyLogger] = None,
num_frames: int = 1,
fps: int = 15,
ctrl_idx: int = 0
):
self.width: int = width
self.height: int = height
@@ -863,6 +867,7 @@ class GenerateImageConfig:
self.extra_values = extra_values if extra_values is not None else []
self.num_frames = num_frames
self.fps = fps
self.ctrl_idx = ctrl_idx
# prompt string will override any settings above
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
self.num_frames = int(content)
elif flag == 'fps':
self.fps = int(content)
elif flag == 'ctrl_idx':
self.ctrl_idx = int(content)
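# Example (assumption about the prompt flag syntax): a sample prompt such as
#   "a photo of a person --ctrl_idx 1"
# would be parsed here and route that sample's control image into the second control slot.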
def post_process_embeddings(
self,

View File

@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
self.position_ids: Optional[List[int]] = None
self.num_control_images = 1
self.num_control_images = self.config.num_control_images
self.token_mask: Optional[torch.Tensor] = None
# setup clip
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
# concat random normal noise onto the latents
# check dimension, this is before they are rearranged
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
ctrl = torch.randn(
latents.shape[0], # bs
latents.shape[1] * self.num_control_images, # ch
latents.shape[2],
latents.shape[3],
device=latents.device,
dtype=latents.dtype
)
latents = torch.cat((latents, ctrl), dim=1)
return latents.detach()
# it is 0-1 need to convert to -1 to 1
control_tensor = control_tensor * 2 - 1
# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
# if we have 1, it comes in like [bs, ch, h, w]
# stack the control tensors to be [bs, ch * num_control_images, h, w]
control_tensor_list = []
if len(control_tensor.shape) == 4:
control_tensor_list.append(control_tensor)
else:
# reshape
control_tensor = control_tensor.view(
control_tensor.shape[0],
control_tensor.shape[1] * control_tensor.shape[2],
control_tensor.shape[3],
control_tensor.shape[4]
)
control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
control_latent_list = []
for control_tensor in control_tensor_list:
do_dropout = random.random() < self.config.control_image_dropout
if do_dropout:
# dropout: substitute a zeroed control latent
control_latent_list.append(torch.zeros_like(batch.latents))
else:
# it is 0-1 need to convert to -1 to 1
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
# encode it
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
# encode it
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
control_latent_list.append(control_latent)
# stack them on the channel dimension
control_latent = torch.cat(control_latent_list, dim=1)
# concat it onto the latents
latents = torch.cat((latents, control_latent), dim=1)
return latents.detach()
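A minimal standalone sketch of the shape handling above, with assumed sizes: a multi-control batch arrives as [bs, num_control_images, ch, h, w], is flattened to [bs, num_control_images * ch, h, w], chunked back into per-control tensors, and the encoded control latents are concatenated onto the denoising latents along the channel dimension.

import torch

bs, n, ch, h, w = 2, 2, 3, 64, 64                     # assumed sizes for illustration
control = torch.rand(bs, n, ch, h, w)                 # two control images per sample, 0-1 range
flat = control.view(bs, n * ch, h, w)                 # [2, 6, 64, 64]
chunks = flat.chunk(n, dim=1)                         # n tensors of [2, 3, 64, 64]
# stand-ins for VAE-encoded 16-channel latents at 1/8 resolution
latents = torch.randn(bs, 16, h // 8, w // 8)
control_latents = [torch.randn(bs, 16, h // 8, w // 8) for _ in chunks]
model_input = torch.cat([latents] + control_latents, dim=1)
print(model_input.shape)                              # torch.Size([2, 48, 8, 8])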

View File

@@ -743,75 +743,100 @@ class ControlFileItemDTOMixin:
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.has_control_image = False
self.control_path: Union[str, None] = None
self.control_path: Union[str, List[str], None] = None
self.control_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
# find the control image path
control_path = dataset_config.control_path
control_path_list = dataset_config.control_path
if not isinstance(control_path_list, list):
control_path_list = [control_path_list]
self.full_size_control_images = dataset_config.full_size_control_images
# we are using control images
img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
for ext in img_ext_list:
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
self.control_path = os.path.join(control_path, file_name_no_ext + ext)
self.has_control_image = True
break
found_control_images = []
for control_path in control_path_list:
for ext in img_ext_list:
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
found_control_images.append(os.path.join(control_path, file_name_no_ext + ext))
self.has_control_image = True
break
self.control_path = found_control_images
if len(self.control_path) == 0:
self.control_path = None
elif len(self.control_path) == 1:
# single control image: keep the non-list form
self.control_path = self.control_path[0]
def load_control_image(self: 'FileItemDTO'):
try:
img = Image.open(self.control_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {self.control_path}")
control_tensors = []
control_path_list = self.control_path
if not isinstance(self.control_path, list):
control_path_list = [self.control_path]
for control_path in control_path_list:
try:
img = Image.open(control_path).convert('RGB')
img = exif_transpose(img)
except Exception as e:
print_acc(f"Error: {e}")
print_acc(f"Error loading image: {control_path}")
if self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
if self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
else:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.control_tensor = self.augment_spatial_control(img, transform=transform)
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
tensor = self.augment_spatial_control(img, transform=transform)
else:
tensor = transform(img)
control_tensors.append(tensor)
if len(control_tensors) == 0:
self.control_tensor = None
elif len(control_tensors) == 1:
self.control_tensor = control_tensors[0]
else:
self.control_tensor = transform(img)
self.control_tensor = torch.stack(control_tensors, dim=0)
def cleanup_control(self: 'FileItemDTO'):
self.control_tensor = None
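A sketch of the directory lookup this mixin now performs when control_path is a list (folder and file names are illustrative); it mirrors the loop above and notes the resulting tensor shape.

import os

# illustrative layout:
#   dataset/images/img_001.jpg
#   dataset/control_depth/img_001.png   <- control_path[0]
#   dataset/control_pose/img_001.png    <- control_path[1]
control_path = ['dataset/control_depth', 'dataset/control_pose']
file_name_no_ext = 'img_001'
found = []
for folder in control_path:
    for ext in ['.jpg', '.jpeg', '.png', '.webp']:
        candidate = os.path.join(folder, file_name_no_ext + ext)
        if os.path.exists(candidate):
            found.append(candidate)
            break
# with two matches, load_control_image stacks the tensors into shape [2, 3, h, w]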

View File

@@ -46,14 +46,14 @@ class ImgEmbedder(torch.nn.Module):
cls,
model: FluxTransformer2DModel,
adapter: 'ControlLoraAdapter',
num_channel_multiplier=2
num_control_images=1
):
if model.__class__.__name__ == 'FluxTransformer2DModel':
x_embedder: torch.nn.Linear = model.x_embedder
img_embedder = cls(
adapter,
orig_layer=x_embedder,
in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels
in_channels=x_embedder.in_features * num_control_images,
out_channels=x_embedder.out_features,
)
@@ -62,7 +62,7 @@ class ImgEmbedder(torch.nn.Module):
x_embedder.forward = img_embedder.forward
# update the config of the transformer
model.config.in_channels = model.config.in_channels * num_channel_multiplier
model.config.in_channels = model.config.in_channels * (num_control_images + 1)
model.config["in_channels"] = model.config.in_channels
return img_embedder
@@ -178,7 +178,11 @@ class ControlLoraAdapter(torch.nn.Module):
if self.train_config.gradient_checkpointing:
self.control_lora.enable_gradient_checkpointing()
self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
self.x_embedder = ImgEmbedder.from_model(
sd.unet,
self,
num_control_images=config.num_control_images
)
self.x_embedder.to(self.device_torch)
def get_params(self):
@@ -230,6 +234,16 @@ class ControlLoraAdapter(torch.nn.Module):
# todo process state dict before loading
if self.control_lora is not None:
self.control_lora.load_weights(lora_sd)
# automatically upgrade the x_embedder if more input dims were added
if self.x_embedder.weight.shape[1] > img_embedder_sd['weight'].shape[1]:
print("Upgrading x_embedder from {} to {}".format(
img_embedder_sd['weight'].shape[1],
self.x_embedder.weight.shape[1]
))
while img_embedder_sd['weight'].shape[1] < self.x_embedder.weight.shape[1]:
img_embedder_sd['weight'] = torch.cat([img_embedder_sd['weight']] * 2, dim=1)
if img_embedder_sd['weight'].shape[1] > self.x_embedder.weight.shape[1]:
img_embedder_sd['weight'] = img_embedder_sd['weight'][:, :self.x_embedder.weight.shape[1]]
self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
def get_state_dict(self):
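A standalone sketch of the x_embedder weight upgrade above, with assumed shapes: a checkpoint trained with one control input has an input width of 64 packed channels; to load it into a layer built for two control inputs, the input columns are tiled and then trimmed to the new width.

import torch

out_features = 3072                        # assumed Flux hidden size
old_w = torch.randn(out_features, 64)      # checkpoint trained with 1 control input
new_in = 128                               # layer now built for 2 control inputs
w = old_w
while w.shape[1] < new_in:
    w = torch.cat([w, w], dim=1)           # duplicate the existing input columns
w = w[:, :new_in]                          # trim any overshoot
print(w.shape)                             # torch.Size([3072, 128])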

View File

@@ -4,7 +4,7 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline, FluxControlPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -15,6 +15,7 @@ from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
from diffusers.image_processor import PipelineImageInput
if is_torch_xla_available():
@@ -1423,4 +1424,293 @@ class FluxWithCFGPipeline(FluxPipeline):
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
return FluxPipelineOutput(images=image)
class FluxAdvancedControlPipeline(FluxControlPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
control_image_idx: int = 0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
be used instead.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
# num_channels_latents = self.transformer.config.in_channels // 8
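# hardcoded: config.in_channels was scaled by the number of control inputs, so derive the base
# 16 latent channels from the stock Flux value (128 // 8) instead of the live config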
num_channels_latents = 128 // 8
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
if control_image.ndim == 4:
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# flux has 64 input channels.
total_controls = (self.transformer.config.in_channels // 64) - 1
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
control_image_list[control_image_idx] = control_image
latent_model_input = torch.cat([latents] + control_image_list, dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
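A hedged usage sketch of the new pipeline. Only control_image and control_image_idx come from this diff; the base model id, the LoRA loading step, and the file paths are assumptions, and the expanded input channels only exist once the trained control LoRA / upgraded x_embedder weights are applied.

import torch
from diffusers.utils import load_image
from toolkit.pipelines import FluxAdvancedControlPipeline

pipe = FluxAdvancedControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# ... apply the trained control LoRA and upgraded x_embedder here (toolkit-specific) ...
depth = load_image("depth_map.png")        # assumed local control image
image = pipe(
    prompt="a photo of a person",
    control_image=depth,
    control_image_idx=1,                   # feed this control into the second control slot
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("out.png")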

View File

@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from einops import rearrange, repeat
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
FluxAdvancedControlPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
@@ -1244,7 +1245,7 @@ class StableDiffusion:
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
# see if it is a control lora
if self.adapter.control_lora is not None:
Pipe = FluxControlPipeline
Pipe = FluxAdvancedControlPipeline
pipeline = Pipe(
vae=self.vae,
@@ -1367,6 +1368,7 @@ class StableDiffusion:
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['control_image'] = validation_image
extra['control_image_idx'] = gen_config.ctrl_idx
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),