mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Handle multi control inputs for control lora training
This commit is contained in:
@@ -1063,7 +1063,6 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||||
# condition the prompt
|
# condition the prompt
|
||||||
# todo handle more than one adapter image
|
# todo handle more than one adapter image
|
||||||
self.adapter.num_control_images = 1
|
|
||||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||||
|
|
||||||
network_weight_list = batch.get_network_weight_list()
|
network_weight_list = batch.get_network_weight_list()
|
||||||
|
|||||||
@@ -241,6 +241,9 @@ class AdapterConfig:
|
|||||||
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
|
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
|
||||||
else:
|
else:
|
||||||
self.lora_config = None
|
self.lora_config = None
|
||||||
|
self.num_control_images: int = kwargs.get('num_control_images', 1)
|
||||||
|
# decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
|
||||||
|
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
@@ -710,7 +713,7 @@ class DatasetConfig:
|
|||||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||||
self.augments: List[str] = kwargs.get('augments', [])
|
self.augments: List[str] = kwargs.get('augments', [])
|
||||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
|
||||||
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
|
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
|
||||||
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
|
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
|
||||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||||
@@ -833,6 +836,7 @@ class GenerateImageConfig:
|
|||||||
logger: Optional[EmptyLogger] = None,
|
logger: Optional[EmptyLogger] = None,
|
||||||
num_frames: int = 1,
|
num_frames: int = 1,
|
||||||
fps: int = 15,
|
fps: int = 15,
|
||||||
|
ctrl_idx: int = 0
|
||||||
):
|
):
|
||||||
self.width: int = width
|
self.width: int = width
|
||||||
self.height: int = height
|
self.height: int = height
|
||||||
@@ -863,6 +867,7 @@ class GenerateImageConfig:
|
|||||||
self.extra_values = extra_values if extra_values is not None else []
|
self.extra_values = extra_values if extra_values is not None else []
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.fps = fps
|
self.fps = fps
|
||||||
|
self.ctrl_idx = ctrl_idx
|
||||||
|
|
||||||
|
|
||||||
# prompt string will override any settings above
|
# prompt string will override any settings above
|
||||||
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
|
|||||||
self.num_frames = int(content)
|
self.num_frames = int(content)
|
||||||
elif flag == 'fps':
|
elif flag == 'fps':
|
||||||
self.fps = int(content)
|
self.fps = int(content)
|
||||||
|
elif flag == 'ctrl_idx':
|
||||||
|
self.ctrl_idx = int(content)
|
||||||
|
|
||||||
def post_process_embeddings(
|
def post_process_embeddings(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
self.position_ids: Optional[List[int]] = None
|
self.position_ids: Optional[List[int]] = None
|
||||||
|
|
||||||
self.num_control_images = 1
|
self.num_control_images = self.config.num_control_images
|
||||||
self.token_mask: Optional[torch.Tensor] = None
|
self.token_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# setup clip
|
# setup clip
|
||||||
@@ -575,19 +575,53 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
# concat random normal noise onto the latents
|
# concat random normal noise onto the latents
|
||||||
# check dimension, this is before they are rearranged
|
# check dimension, this is before they are rearranged
|
||||||
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
# it is latent_model_input = torch.cat([latents, control_image], dim=2) after rearranging
|
||||||
latents = torch.cat((latents, torch.randn_like(latents)), dim=1)
|
ctrl = torch.randn(
|
||||||
|
latents.shape[0], # bs
|
||||||
|
latents.shape[1] * self.num_control_images, # ch
|
||||||
|
latents.shape[2],
|
||||||
|
latents.shape[3],
|
||||||
|
device=latents.device,
|
||||||
|
dtype=latents.dtype
|
||||||
|
)
|
||||||
|
latents = torch.cat((latents, ctrl), dim=1)
|
||||||
return latents.detach()
|
return latents.detach()
|
||||||
# it is 0-1 need to convert to -1 to 1
|
# if we have multiple control tensors, they come in like [bs, num_control_images, ch, h, w]
|
||||||
control_tensor = control_tensor * 2 - 1
|
# if we have 1, it comes in like [bs, ch, h, w]
|
||||||
|
# stack out control tensors to be [bs, ch * num_control_images, h, w]
|
||||||
|
|
||||||
|
control_tensor_list = []
|
||||||
|
if len(control_tensor.shape) == 4:
|
||||||
|
control_tensor_list.append(control_tensor)
|
||||||
|
else:
|
||||||
|
# reshape
|
||||||
|
control_tensor = control_tensor.view(
|
||||||
|
control_tensor.shape[0],
|
||||||
|
control_tensor.shape[1] * control_tensor.shape[2],
|
||||||
|
control_tensor.shape[3],
|
||||||
|
control_tensor.shape[4]
|
||||||
|
)
|
||||||
|
control_tensor_list = control_tensor.chunk(self.num_control_images, dim=1)
|
||||||
|
control_latent_list = []
|
||||||
|
for control_tensor in control_tensor_list:
|
||||||
|
do_dropout = random.random() < self.config.control_image_dropout
|
||||||
|
if do_dropout:
|
||||||
|
# dropout with noise
|
||||||
|
control_latent_list.append(torch.zeros_like(batch.latents))
|
||||||
|
else:
|
||||||
|
# it is 0-1 need to convert to -1 to 1
|
||||||
|
control_tensor = control_tensor * 2 - 1
|
||||||
|
|
||||||
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
control_tensor = control_tensor.to(sd.vae_device_torch, dtype=sd.torch_dtype)
|
||||||
|
|
||||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||||
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
|
||||||
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bicubic')
|
||||||
|
|
||||||
# encode it
|
# encode it
|
||||||
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
control_latent = sd.encode_images(control_tensor).to(latents.device, latents.dtype)
|
||||||
|
control_latent_list.append(control_latent)
|
||||||
|
# stack them on the channel dimension
|
||||||
|
control_latent = torch.cat(control_latent_list, dim=1)
|
||||||
# concat it onto the latents
|
# concat it onto the latents
|
||||||
latents = torch.cat((latents, control_latent), dim=1)
|
latents = torch.cat((latents, control_latent), dim=1)
|
||||||
return latents.detach()
|
return latents.detach()
|
||||||
|
|||||||
@@ -743,75 +743,100 @@ class ControlFileItemDTOMixin:
|
|||||||
if hasattr(super(), '__init__'):
|
if hasattr(super(), '__init__'):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.has_control_image = False
|
self.has_control_image = False
|
||||||
self.control_path: Union[str, None] = None
|
self.control_path: Union[str, List[str], None] = None
|
||||||
self.control_tensor: Union[torch.Tensor, None] = None
|
self.control_tensor: Union[torch.Tensor, None] = None
|
||||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
self.full_size_control_images = False
|
self.full_size_control_images = False
|
||||||
if dataset_config.control_path is not None:
|
if dataset_config.control_path is not None:
|
||||||
# find the control image path
|
# find the control image path
|
||||||
control_path = dataset_config.control_path
|
control_path_list = dataset_config.control_path
|
||||||
|
if not isinstance(control_path_list, list):
|
||||||
|
control_path_list = [control_path_list]
|
||||||
self.full_size_control_images = dataset_config.full_size_control_images
|
self.full_size_control_images = dataset_config.full_size_control_images
|
||||||
# we are using control images
|
# we are using control images
|
||||||
img_path = kwargs.get('path', None)
|
img_path = kwargs.get('path', None)
|
||||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||||
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
for ext in img_ext_list:
|
|
||||||
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
|
found_control_images = []
|
||||||
self.control_path = os.path.join(control_path, file_name_no_ext + ext)
|
for control_path in control_path_list:
|
||||||
self.has_control_image = True
|
for ext in img_ext_list:
|
||||||
break
|
if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
|
||||||
|
found_control_images.append(os.path.join(control_path, file_name_no_ext + ext))
|
||||||
|
self.has_control_image = True
|
||||||
|
break
|
||||||
|
self.control_path = found_control_images
|
||||||
|
if len(self.control_path) == 0:
|
||||||
|
self.control_path = None
|
||||||
|
elif len(self.control_path) == 1:
|
||||||
|
# only do one
|
||||||
|
self.control_path = self.control_path[0]
|
||||||
|
|
||||||
def load_control_image(self: 'FileItemDTO'):
|
def load_control_image(self: 'FileItemDTO'):
|
||||||
try:
|
control_tensors = []
|
||||||
img = Image.open(self.control_path).convert('RGB')
|
control_path_list = self.control_path
|
||||||
img = exif_transpose(img)
|
if not isinstance(self.control_path, list):
|
||||||
except Exception as e:
|
control_path_list = [self.control_path]
|
||||||
print_acc(f"Error: {e}")
|
|
||||||
print_acc(f"Error loading image: {self.control_path}")
|
for control_path in control_path_list:
|
||||||
|
try:
|
||||||
|
img = Image.open(control_path).convert('RGB')
|
||||||
|
img = exif_transpose(img)
|
||||||
|
except Exception as e:
|
||||||
|
print_acc(f"Error: {e}")
|
||||||
|
print_acc(f"Error loading image: {control_path}")
|
||||||
|
|
||||||
if self.full_size_control_images:
|
if self.full_size_control_images:
|
||||||
# we just scale them to 512x512:
|
# we just scale them to 512x512:
|
||||||
w, h = img.size
|
w, h = img.size
|
||||||
img = img.resize((512, 512), Image.BICUBIC)
|
img = img.resize((512, 512), Image.BICUBIC)
|
||||||
|
|
||||||
else:
|
|
||||||
w, h = img.size
|
|
||||||
if w > h and self.scale_to_width < self.scale_to_height:
|
|
||||||
# throw error, they should match
|
|
||||||
raise ValueError(
|
|
||||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
|
||||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
|
||||||
# throw error, they should match
|
|
||||||
raise ValueError(
|
|
||||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
|
||||||
|
|
||||||
if self.flip_x:
|
|
||||||
# do a flip
|
|
||||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
|
||||||
if self.flip_y:
|
|
||||||
# do a flip
|
|
||||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
|
||||||
|
|
||||||
if self.dataset_config.buckets:
|
|
||||||
# scale and crop based on file item
|
|
||||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
|
||||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
|
||||||
# crop
|
|
||||||
img = img.crop((
|
|
||||||
self.crop_x,
|
|
||||||
self.crop_y,
|
|
||||||
self.crop_x + self.crop_width,
|
|
||||||
self.crop_y + self.crop_height
|
|
||||||
))
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Control images not supported for non-bucket datasets")
|
w, h = img.size
|
||||||
transform = transforms.Compose([
|
if w > h and self.scale_to_width < self.scale_to_height:
|
||||||
transforms.ToTensor(),
|
# throw error, they should match
|
||||||
])
|
raise ValueError(
|
||||||
if self.aug_replay_spatial_transforms:
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
self.control_tensor = self.augment_spatial_control(img, transform=transform)
|
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||||
|
# throw error, they should match
|
||||||
|
raise ValueError(
|
||||||
|
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||||
|
|
||||||
|
if self.flip_x:
|
||||||
|
# do a flip
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
if self.flip_y:
|
||||||
|
# do a flip
|
||||||
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
if self.dataset_config.buckets:
|
||||||
|
# scale and crop based on file item
|
||||||
|
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||||
|
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||||
|
# crop
|
||||||
|
img = img.crop((
|
||||||
|
self.crop_x,
|
||||||
|
self.crop_y,
|
||||||
|
self.crop_x + self.crop_width,
|
||||||
|
self.crop_y + self.crop_height
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
raise Exception("Control images not supported for non-bucket datasets")
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
if self.aug_replay_spatial_transforms:
|
||||||
|
tensor = self.augment_spatial_control(img, transform=transform)
|
||||||
|
else:
|
||||||
|
tensor = transform(img)
|
||||||
|
control_tensors.append(tensor)
|
||||||
|
|
||||||
|
if len(control_tensors) == 0:
|
||||||
|
self.control_tensor = None
|
||||||
|
elif len(control_tensors) == 1:
|
||||||
|
self.control_tensor = control_tensors[0]
|
||||||
else:
|
else:
|
||||||
self.control_tensor = transform(img)
|
self.control_tensor = torch.stack(control_tensors, dim=0)
|
||||||
|
|
||||||
def cleanup_control(self: 'FileItemDTO'):
|
def cleanup_control(self: 'FileItemDTO'):
|
||||||
self.control_tensor = None
|
self.control_tensor = None
|
||||||
|
|||||||
@@ -46,14 +46,14 @@ class ImgEmbedder(torch.nn.Module):
|
|||||||
cls,
|
cls,
|
||||||
model: FluxTransformer2DModel,
|
model: FluxTransformer2DModel,
|
||||||
adapter: 'ControlLoraAdapter',
|
adapter: 'ControlLoraAdapter',
|
||||||
num_channel_multiplier=2
|
num_control_images=1
|
||||||
):
|
):
|
||||||
if model.__class__.__name__ == 'FluxTransformer2DModel':
|
if model.__class__.__name__ == 'FluxTransformer2DModel':
|
||||||
x_embedder: torch.nn.Linear = model.x_embedder
|
x_embedder: torch.nn.Linear = model.x_embedder
|
||||||
img_embedder = cls(
|
img_embedder = cls(
|
||||||
adapter,
|
adapter,
|
||||||
orig_layer=x_embedder,
|
orig_layer=x_embedder,
|
||||||
in_channels=x_embedder.in_features * (num_channel_multiplier - 1), # only our new channels
|
in_channels=x_embedder.in_features * num_control_images,
|
||||||
out_channels=x_embedder.out_features,
|
out_channels=x_embedder.out_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class ImgEmbedder(torch.nn.Module):
|
|||||||
x_embedder.forward = img_embedder.forward
|
x_embedder.forward = img_embedder.forward
|
||||||
|
|
||||||
# update the config of the transformer
|
# update the config of the transformer
|
||||||
model.config.in_channels = model.config.in_channels * num_channel_multiplier
|
model.config.in_channels = model.config.in_channels * (num_control_images + 1)
|
||||||
model.config["in_channels"] = model.config.in_channels
|
model.config["in_channels"] = model.config.in_channels
|
||||||
|
|
||||||
return img_embedder
|
return img_embedder
|
||||||
@@ -178,7 +178,11 @@ class ControlLoraAdapter(torch.nn.Module):
|
|||||||
if self.train_config.gradient_checkpointing:
|
if self.train_config.gradient_checkpointing:
|
||||||
self.control_lora.enable_gradient_checkpointing()
|
self.control_lora.enable_gradient_checkpointing()
|
||||||
|
|
||||||
self.x_embedder = ImgEmbedder.from_model(sd.unet, self)
|
self.x_embedder = ImgEmbedder.from_model(
|
||||||
|
sd.unet,
|
||||||
|
self,
|
||||||
|
num_control_images=config.num_control_images
|
||||||
|
)
|
||||||
self.x_embedder.to(self.device_torch)
|
self.x_embedder.to(self.device_torch)
|
||||||
|
|
||||||
def get_params(self):
|
def get_params(self):
|
||||||
@@ -230,6 +234,16 @@ class ControlLoraAdapter(torch.nn.Module):
|
|||||||
# todo process state dict before loading
|
# todo process state dict before loading
|
||||||
if self.control_lora is not None:
|
if self.control_lora is not None:
|
||||||
self.control_lora.load_weights(lora_sd)
|
self.control_lora.load_weights(lora_sd)
|
||||||
|
# automatically upgrade the x imbedder if more dims are added
|
||||||
|
if self.x_embedder.weight.shape[1] > img_embedder_sd['weight'].shape[1]:
|
||||||
|
print("Upgrading x_embedder from {} to {}".format(
|
||||||
|
img_embedder_sd['weight'].shape[1],
|
||||||
|
self.x_embedder.weight.shape[1]
|
||||||
|
))
|
||||||
|
while img_embedder_sd['weight'].shape[1] < self.x_embedder.weight.shape[1]:
|
||||||
|
img_embedder_sd['weight'] = torch.cat([img_embedder_sd['weight'] ] * 2, dim=1)
|
||||||
|
if img_embedder_sd['weight'].shape[1] > self.x_embedder.weight.shape[1]:
|
||||||
|
img_embedder_sd['weight'] = img_embedder_sd['weight'][:, :self.x_embedder.weight.shape[1]]
|
||||||
self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
|
self.x_embedder.load_state_dict(img_embedder_sd, strict=False)
|
||||||
|
|
||||||
def get_state_dict(self):
|
def get_state_dict(self):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
|
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline, FluxControlPipeline
|
||||||
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
@@ -15,6 +15,7 @@ from diffusers.utils import is_torch_xla_available
|
|||||||
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||||
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
|
||||||
|
from diffusers.image_processor import PipelineImageInput
|
||||||
|
|
||||||
|
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
@@ -1423,4 +1424,293 @@ class FluxWithCFGPipeline(FluxPipeline):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
return FluxPipelineOutput(images=image)
|
return FluxPipelineOutput(images=image)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||||
|
control_image: PipelineImageInput = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 28,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
guidance_scale: float = 3.5,
|
||||||
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||||
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
control_image_idx: int = 0,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
prompt_2 (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||||
|
will be used instead
|
||||||
|
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||||
|
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||||
|
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||||
|
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
|
||||||
|
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
|
||||||
|
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
|
||||||
|
images must be passed as a list such that each element of the list can be correctly batched for input
|
||||||
|
to a single ControlNet.
|
||||||
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
|
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||||
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
|
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||||
|
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||||
|
will be used.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
|
to make generation deterministic.
|
||||||
|
latents (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will ge generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||||
|
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||||
|
joint_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
callback_on_step_end (`Callable`, *optional*):
|
||||||
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||||
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||||
|
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||||
|
`callback_on_step_end_tensor_inputs`.
|
||||||
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||||
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||||
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||||
|
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||||
|
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||||
|
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||||
|
images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
height = height or self.default_sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.default_sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
prompt_2,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._guidance_scale = guidance_scale
|
||||||
|
self._joint_attention_kwargs = joint_attention_kwargs
|
||||||
|
self._interrupt = False
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
|
||||||
|
# 3. Prepare text embeddings
|
||||||
|
lora_scale = (
|
||||||
|
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||||
|
)
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
pooled_prompt_embeds,
|
||||||
|
text_ids,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_2=prompt_2,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
device=device,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
lora_scale=lora_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare latent variables
|
||||||
|
# num_channels_latents = self.transformer.config.in_channels // 8
|
||||||
|
num_channels_latents = 128 // 8
|
||||||
|
|
||||||
|
control_image = self.prepare_image(
|
||||||
|
image=control_image,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
batch_size=batch_size * num_images_per_prompt,
|
||||||
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=device,
|
||||||
|
dtype=self.vae.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if control_image.ndim == 4:
|
||||||
|
control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
|
||||||
|
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||||
|
|
||||||
|
height_control_image, width_control_image = control_image.shape[2:]
|
||||||
|
control_image = self._pack_latents(
|
||||||
|
control_image,
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height_control_image,
|
||||||
|
width_control_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
latents, latent_image_ids = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Prepare timesteps
|
||||||
|
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||||
|
image_seq_len = latents.shape[1]
|
||||||
|
mu = calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
self.scheduler.config.get("base_image_seq_len", 256),
|
||||||
|
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||||
|
self.scheduler.config.get("base_shift", 0.5),
|
||||||
|
self.scheduler.config.get("max_shift", 1.15),
|
||||||
|
)
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
self.scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=sigmas,
|
||||||
|
mu=mu,
|
||||||
|
)
|
||||||
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||||
|
self._num_timesteps = len(timesteps)
|
||||||
|
|
||||||
|
# handle guidance
|
||||||
|
if self.transformer.config.guidance_embeds:
|
||||||
|
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||||
|
guidance = guidance.expand(latents.shape[0])
|
||||||
|
else:
|
||||||
|
guidance = None
|
||||||
|
|
||||||
|
# flux has 64 input channels.
|
||||||
|
total_controls = (self.transformer.config.in_channels // 64) - 1
|
||||||
|
|
||||||
|
# 6. Denoising loop
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
if self.interrupt:
|
||||||
|
continue
|
||||||
|
|
||||||
|
control_image_list = [torch.zeros_like(latents) for _ in range(total_controls)]
|
||||||
|
control_image_list[control_image_idx] = control_image
|
||||||
|
|
||||||
|
latent_model_input = torch.cat([latents] + control_image_list, dim=2)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
|
noise_pred = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=guidance,
|
||||||
|
pooled_projections=pooled_prompt_embeds,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
txt_ids=text_ids,
|
||||||
|
img_ids=latent_image_ids,
|
||||||
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents_dtype = latents.dtype
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||||
|
|
||||||
|
if latents.dtype != latents_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||||
|
latents = latents.to(latents_dtype)
|
||||||
|
|
||||||
|
if callback_on_step_end is not None:
|
||||||
|
callback_kwargs = {}
|
||||||
|
for k in callback_on_step_end_tensor_inputs:
|
||||||
|
callback_kwargs[k] = locals()[k]
|
||||||
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||||
|
|
||||||
|
latents = callback_outputs.pop("latents", latents)
|
||||||
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
else:
|
||||||
|
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||||
|
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||||
|
image = self.vae.decode(latents, return_dict=False)[0]
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||||
|
|
||||||
|
# Offload all models
|
||||||
|
self.maybe_free_model_hooks()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return FluxPipelineOutput(images=image)
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +43,8 @@ from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import torch
|
import torch
|
||||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
|
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline, \
|
||||||
|
FluxAdvancedControlPipeline
|
||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||||
@@ -1244,7 +1245,7 @@ class StableDiffusion:
|
|||||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||||
# see if it is a control lora
|
# see if it is a control lora
|
||||||
if self.adapter.control_lora is not None:
|
if self.adapter.control_lora is not None:
|
||||||
Pipe = FluxControlPipeline
|
Pipe = FluxAdvancedControlPipeline
|
||||||
|
|
||||||
pipeline = Pipe(
|
pipeline = Pipe(
|
||||||
vae=self.vae,
|
vae=self.vae,
|
||||||
@@ -1367,6 +1368,7 @@ class StableDiffusion:
|
|||||||
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
|
if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
|
||||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||||
extra['control_image'] = validation_image
|
extra['control_image'] = validation_image
|
||||||
|
extra['control_image_idx'] = gen_config.ctrl_idx
|
||||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
|
|||||||
Reference in New Issue
Block a user