Initial training script for photomaker training. Needs a little more work.

This commit is contained in:
Jaret Burkett
2024-01-15 18:46:26 -07:00
parent 5276975fb0
commit eebd3c8212
8 changed files with 1183 additions and 24 deletions

View File

@@ -14,6 +14,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileIte
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -148,7 +149,8 @@ class SDTrainer(BaseSDTrainProcess):
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
)[0]
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
# remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
denoised_latent = denoised_latent - residual_noise
@@ -298,7 +300,8 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
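# Illustrative sketch of the min-SNR-gamma weighting used above. This is not the exact
# apply_snr_weight implementation (its fixed=True variant may differ); it assumes a DDPM-style
# noise scheduler that exposes alphas_cumprod.
import torch

def min_snr_loss_weight(timesteps: torch.Tensor, noise_scheduler, gamma: float) -> torch.Tensor:
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    alpha_bar = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # clamp the weight so low-noise (high-SNR) timesteps do not dominate the loss
    return torch.clamp(snr, max=gamma) / snr

# usage: loss = loss * min_snr_loss_weight(timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)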
@@ -631,7 +634,9 @@ class SDTrainer(BaseSDTrainProcess):
self.network.is_active = False
can_disable_adapter = False
was_adapter_active = False
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
can_disable_adapter = True
was_adapter_active = self.adapter.is_active
self.adapter.is_active = False
@@ -698,6 +703,13 @@ class SDTrainer(BaseSDTrainProcess):
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
self.adapter.num_control_images = 1
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
@@ -706,7 +718,8 @@ class SDTrainer(BaseSDTrainProcess):
has_clip_image = batch.clip_image_tensor is not None
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
@@ -752,7 +765,6 @@ class SDTrainer(BaseSDTrainProcess):
if batch.clip_image_tensor is not None:
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
@@ -879,12 +891,13 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier_list,
prompt_2_list
):
# if self.train_config.negative_prompt is not None:
#     # add negative prompt
#     conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in range(len(conditioned_prompts))]
#     if prompt_2 is not None:
#         prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
# encode clip adapter here so embeds are active for tokenizer
@@ -977,7 +990,6 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
image_size = self.adapter.input_size
@@ -1029,11 +1041,11 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
unconditional_clip_embeds = unconditional_clip_embeds.detach()
with self.timer('encode_adapter'):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our scheduler
@@ -1060,7 +1072,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
do_inverted_masked_prior = True
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
@@ -1074,6 +1087,25 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeds=unconditional_embeds
)
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True
)
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:

View File

@@ -18,6 +18,7 @@ import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
@@ -34,7 +35,7 @@ from toolkit.progress_bar import ToolkitProgressBar
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
load_ip_adapter_model, load_custom_adapter_model
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -141,7 +142,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -412,8 +413,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_name += '_t2i'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
elif self.adapter_config.type == 'ip':
adapter_name += '_ip'
else:
adapter_name += '_adapter'
filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
@@ -931,8 +934,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
suffix = 'clip'
elif self.adapter_config.type == 'reference':
suffix = 'ref'
elif self.adapter_config.type.startswith('ip'):
suffix = 'ip'
else:
suffix = 'adapter'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
@@ -967,13 +972,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
sd=self.sd,
adapter_config=self.adapter_config,
)
elif self.adapter_config.type.startswith('ip'):
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
if self.train_config.gradient_checkpointing:
self.adapter.enable_gradient_checkpointing()
else:
self.adapter = CustomAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
if latest_save_path is not None:
# load adapter from path
@@ -985,7 +995,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
elif self.adapter_config.type.startswith('ip'):
# ip adapter
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
@@ -993,6 +1003,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
else:
# custom adapter
loaded_state_dict = load_custom_adapter_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params

View File

@@ -127,7 +127,7 @@ class NetworkConfig:
self.conv = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
class AdapterConfig:
@@ -162,6 +162,8 @@ class AdapterConfig:
self.trigger = kwargs.get('trigger', 'tri993r')
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
self.class_names = kwargs.get('class_names', [])
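# Illustrative construction of the new photo_maker adapter config in Python. Values are
# hypothetical; 'train', 'train_image_encoder' and 'image_encoder_path' are assumed to be
# accepted kwargs because CustomAdapter reads them from the config elsewhere in this commit.
from toolkit.config_modules import AdapterConfig

photo_maker_config = AdapterConfig(
    type='photo_maker',
    train=True,
    train_image_encoder=False,
    image_encoder_path=None,                # None -> PhotoMakerCLIPEncoder() with the default CLIP-L config
    trigger='tri993r',
    trigger_class_name='person',
    class_names=['man', 'woman', 'person'],
)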
class EmbeddingConfig:
def __init__(self, **kwargs):

toolkit/custom_adapter.py (new file, 417 lines)
View File

@@ -0,0 +1,417 @@
import torch
import sys
from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig, AdapterTypes
from toolkit.prompt_utils import PromptEmbeds
import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPVisionModel,
AutoImageProcessor,
ConvNextModel,
ConvNextForImageClassification,
ConvNextImageProcessor
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch.nn.functional as F
class CustomAdapter(torch.nn.Module):
def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
super().__init__()
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.device = self.sd_ref().unet.device
self.image_processor: CLIPImageProcessor = None
self.input_size = 224
self.adapter_type: AdapterTypes = self.config.type
self.current_scale = 1.0
self.is_active = True
self.flag_word = "fla9wor0"
self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
self.fuse_module: FuseModule = None
self.lora: None = None
self.position_ids: Optional[List[int]] = None
self.num_control_images = 1
self.token_mask: Optional[torch.Tensor] = None
# setup clip
self.setup_clip()
# add for dataloader
self.clip_image_processor = self.image_processor
self.setup_adapter()
# try to load from our name_or_path
if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
# add the trigger word to the tokenizer
if isinstance(self.sd_ref().tokenizer, list):
for tokenizer in self.sd_ref().tokenizer:
tokenizer.add_tokens([self.flag_word], special_tokens=True)
else:
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
def setup_adapter(self):
if self.adapter_type == 'photo_maker':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
self.fuse_module = FuseModule(embed_dim)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
def forward(self, *args, **kwargs):
if self.adapter_type == 'photo_maker':
id_pixel_values = args[0]
prompt_embeds: PromptEmbeds = args[1]
class_tokens_mask = args[2]
grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
with torch.set_grad_enabled(grads_on_image_encoder):
id_embeds = self.vision_encoder(id_pixel_values, do_projection2=False)
if not grads_on_image_encoder:
id_embeds = id_embeds.detach()
prompt_embeds = prompt_embeds.detach()
updated_prompt_embeds = self.fuse_module(
prompt_embeds, id_embeds, class_tokens_mask
)
return updated_prompt_embeds
else:
raise NotImplementedError
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type == 'photo_maker':
try:
self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
except EnvironmentError:
self.image_processor = CLIPImageProcessor()
if self.config.image_encoder_path is None:
self.vision_encoder = PhotoMakerCLIPEncoder()
else:
self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path)
elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.image_processor = CLIPImageProcessor()
self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(
adapter_config.image_encoder_path,
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'siglip':
from transformers import SiglipImageProcessor, SiglipVisionModel
try:
self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.image_processor = SiglipImageProcessor()
self.vision_encoder = SiglipVisionModel.from_pretrained(
adapter_config.image_encoder_path,
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit':
try:
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.image_processor = ViTFeatureExtractor()
self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'safe':
try:
self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.image_processor = SAFEImageProcessor()
self.vision_encoder = SAFEVisionModel(
in_channels=3,
num_tokens=self.config.safe_tokens,
num_vectors=sd.unet.config['cross_attention_dim'],
reducer_channels=self.config.safe_reducer_channels,
channels=self.config.safe_channels,
downscale_factor=8
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'convnext':
try:
self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.image_processor = ConvNextImageProcessor(
size=320,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
)
self.vision_encoder = ConvNextForImageClassification.from_pretrained(
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit-hybrid':
try:
self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.image_processor = ViTHybridImageProcessor(
size=320,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
)
self.vision_encoder = ViTHybridForImageClassification.from_pretrained(
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
else:
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
self.input_size = self.vision_encoder.config.image_size
if self.config.image_encoder_arch == 'clip+':
# self.image_processor.config
# the preprocessor takes images at 4x the CLIP input size and downscales them, so adjust the input size accordingly
preprocessor_input_size = self.vision_encoder.config.image_size * 4
# update the preprocessor so images come in at the right size
self.image_processor.size['shortest_edge'] = preprocessor_input_size
self.image_processor.crop_size['height'] = preprocessor_input_size
self.image_processor.crop_size['width'] = preprocessor_input_size
self.preprocessor = CLIPImagePreProcessor(
input_size=preprocessor_input_size,
clip_input_size=self.vision_encoder.config.image_size,
)
if 'height' in self.image_processor.size:
self.input_size = self.image_processor.size['height']
else:
self.input_size = self.image_processor.crop_size['height']
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
if 'lora_weights' in state_dict:
# todo add LoRA
# self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
# self.sd_ref().pipeline.fuse_lora()
pass
if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
# check to see if the fuse weights are there
fuse_weights = {}
for k, v in state_dict['id_encoder'].items():
if k.startswith('fuse_module'):
k = k.replace('fuse_module.', '')
fuse_weights[k] = v
if len(fuse_weights) > 0:
self.fuse_module.load_state_dict(fuse_weights, strict=strict)
if 'fuse_module' in state_dict:
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
pass
def state_dict(self) -> OrderedDict:
    state_dict = OrderedDict()
    if self.adapter_type == 'photo_maker':
        if self.config.train_image_encoder:
            state_dict["id_encoder"] = self.vision_encoder.state_dict()
        state_dict["fuse_module"] = self.fuse_module.state_dict()
        # todo save LoRA
        return state_dict
    else:
        raise NotImplementedError
def condition_prompt(
self,
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type == 'photo_maker':
if is_unconditional:
return prompt
else:
with torch.no_grad():
was_list = isinstance(prompt, list)
if not was_list:
prompt_list = [prompt]
else:
prompt_list = prompt
new_prompt_list = []
token_mask_list = []
for prompt in prompt_list:
our_class = None
# find a class in the prompt
prompt_parts = prompt.split(' ')
prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0]
new_prompt_parts = []
tokened_prompt_parts = []
for idx, prompt_part in enumerate(prompt_parts):
new_prompt_parts.append(prompt_part)
tokened_prompt_parts.append(prompt_part)
if prompt_part in self.config.class_names:
our_class = prompt_part
# add the flag word
tokened_prompt_parts.append(self.flag_word)
if self.num_control_images > 1:
    # repeat the class word once per additional control image so the
    # token mask built below lines up with num_control_images class tokens
    for _ in range(self.num_control_images - 1):
        new_prompt_parts.append(prompt_part)
# add the rest
tokened_prompt_parts.extend(prompt_parts[idx + 1:])
new_prompt_parts.extend(prompt_parts[idx + 1:])
break
prompt = " ".join(new_prompt_parts)
tokened_prompt = " ".join(tokened_prompt_parts)
if our_class is None:
# add the first one to the front of the prompt
tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
our_class = self.config.class_names[0]
prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
# add the prompt to the list
new_prompt_list.append(prompt)
# tokenize them with just the first tokenizer
tokenizer = self.sd_ref().tokenizer
if isinstance(tokenizer, list):
tokenizer = tokenizer[0]
flag_token = tokenizer.convert_tokens_to_ids(self.flag_word)
tokenized_prompt = tokenizer.encode(prompt)
tokenized_tokened_prompt = tokenizer.encode(tokened_prompt)
flag_idx = tokenized_tokened_prompt.index(flag_token)
class_token = tokenized_prompt[flag_idx - 1]
boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
boolean_mask = boolean_mask.to(self.device)
# zero pad it to 77
boolean_mask = F.pad(boolean_mask, (0, 77 - boolean_mask.shape[0]), value=False)
token_mask_list.append(boolean_mask)
self.token_mask = torch.cat(token_mask_list, dim=0).to(self.device)
prompt_list = new_prompt_list
if not was_list:
prompt = prompt_list[0]
else:
prompt = prompt_list
return prompt
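# Worked example of the rewrite above (hypothetical values): with class_names=['woman'] and
# num_control_images=2, the prompt "photo of a woman smiling" produces
#   tokened_prompt: "photo of a woman fla9wor0 smiling"  -> only used to locate the class token
#   prompt:         "photo of a woman woman smiling"     -> class word repeated once per control image
# and token_mask is True at the two "woman" token positions, zero-padded out to 77 entries.
# If no class name is found in the prompt, class_names[0] and the flag word are prepended instead.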
def condition_encoded_embeds(
self,
tensors_0_1: torch.Tensor,
prompt_embeds: PromptEmbeds,
is_training=False,
has_been_preprocessed=False,
is_unconditional=False
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker':
if is_unconditional:
# we don't condition the negative embeds for PhotoMaker
return prompt_embeds
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
# Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
clip_image = clip_image.unsqueeze(1)
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
do_projection2=isinstance(self.sd_ref().text_encoder, list),
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
).detach()
prompt_embeds.text_embeds = self.fuse_module(
prompt_embeds.text_embeds,
id_embeds,
self.token_mask
)
return prompt_embeds
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.type == 'photo_maker':
yield from self.fuse_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
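# Illustrative call pattern for CustomAdapter at inference time, mirroring how
# StableDiffusion.generate_images uses it later in this commit. `sd` (a StableDiffusion
# instance) and `adapter_config` (a photo_maker AdapterConfig) are assumed to exist already.
import torch
from toolkit.custom_adapter import CustomAdapter

adapter = CustomAdapter(sd=sd, adapter_config=adapter_config)
adapter.num_control_images = 1

prompt = adapter.condition_prompt("photo of a woman in a garden")   # repeats the class word, builds token_mask
prompt_embeds = adapter.sd_ref().encode_prompt(prompt)              # PromptEmbeds
id_image = torch.rand(1, 3, 224, 224)                               # one ID image in the 0-1 range
prompt_embeds = adapter.condition_encoded_embeds(
    tensors_0_1=id_image,
    prompt_embeds=prompt_embeds,
    is_training=False,
    has_been_preprocessed=False,
)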

toolkit/photomaker.py (new file, 141 lines)
View File

@@ -0,0 +1,141 @@
# Merge the image encoder and fuse module to create an ID encoder.
# Given multiple ID images, we can directly obtain updated text embeddings that contain a stacked ID embedding.
import torch
import torch.nn as nn
from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
from transformers.models.clip.configuration_clip import CLIPVisionConfig
from transformers import PretrainedConfig
VISION_CONFIG_DICT = {
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768
}
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class FuseModule(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
self.layer_norm = nn.LayerNorm(embed_dim)
def fuse_fn(self, prompt_embeds, id_embeds):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward(
self,
prompt_embeds,
id_embeds,
class_tokens_mask,
) -> torch.Tensor:
# id_embeds shape: [b, max_num_inputs, 1, 2048]
id_embeds = id_embeds.to(prompt_embeds.dtype)
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
batch_size, max_num_inputs = id_embeds.shape[:2]
# seq_length: 77
seq_length = prompt_embeds.shape[1]
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
flat_id_embeds = id_embeds.view(
-1, id_embeds.shape[-2], id_embeds.shape[-1]
)
# valid_id_mask [b*max_num_inputs]
valid_id_mask = (
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
< num_inputs[:, None]
)
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
class_tokens_mask = class_tokens_mask.view(-1)
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
# slice out the image token embeddings
image_token_embeds = prompt_embeds[class_tokens_mask]
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
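# Shape sketch for FuseModule (hypothetical sizes; 2048 is the concatenated CLIP projection dim
# used by the encoders below):
import torch
from toolkit.photomaker import FuseModule   # import path assumed

fuse = FuseModule(2048)
prompt_embeds = torch.randn(1, 77, 2048)                    # (batch, seq, dim)
id_embeds = torch.randn(1, 2, 1, 2048)                      # (batch, num_id_images, 1, dim)
class_tokens_mask = torch.zeros(1, 77, dtype=torch.bool)
class_tokens_mask[0, 4:6] = True                            # the two duplicated class-word positions
out = fuse(prompt_embeds, id_embeds, class_tokens_mask)
print(out.shape)                                            # torch.Size([1, 77, 2048]); only masked rows are fused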
class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
def __init__(self, config=None, *model_args, **model_kwargs):
if config is None:
config = CLIPVisionConfig(**VISION_CONFIG_DICT)
super().__init__(config, *model_args, **model_kwargs)
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
self.fuse_module = FuseModule(2048)
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[1]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(
prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
def __init__(self, config=None, *model_args, **model_kwargs):
if config is None:
config = CLIPVisionConfig(**VISION_CONFIG_DICT)
super().__init__(config, *model_args, **model_kwargs)
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
def forward(self, id_pixel_values, do_projection2=True):
b, num_inputs, c, h, w = id_pixel_values.shape
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
shared_id_embeds = self.vision_model(id_pixel_values)[1]
id_embeds = self.visual_projection(shared_id_embeds)
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
if do_projection2:
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
return id_embeds
if __name__ == "__main__":
PhotoMakerIDEncoder()
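# Quick shape check for PhotoMakerCLIPEncoder (assumes the default 224x224 CLIP ViT-L/14
# config defined above; weights are randomly initialized here, so this is shape-only):
import torch
from toolkit.photomaker import PhotoMakerCLIPEncoder

encoder = PhotoMakerCLIPEncoder().eval()
id_pixel_values = torch.randn(1, 2, 3, 224, 224)            # (batch, num_id_images, C, H, W)
with torch.no_grad():
    id_embeds = encoder(id_pixel_values, do_projection2=True)
print(id_embeds.shape)                                      # torch.Size([1, 2, 1, 2048]) -> 768 + 1280 projections concatenated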

View File

@@ -0,0 +1,491 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from collections import OrderedDict
import os
import PIL
import numpy as np
import torch
from torchvision import transforms as T
from safetensors import safe_open
from huggingface_hub.utils import validate_hf_hub_args
from transformers import CLIPImageProcessor, CLIPTokenizer
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.utils import (
_get_model_file,
is_transformers_available,
logging,
)
from .photomaker import PhotoMakerIDEncoder
PipelineImageInput = Union[
PIL.Image.Image,
torch.FloatTensor,
List[PIL.Image.Image],
List[torch.FloatTensor],
]
class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
@validate_hf_hub_args
def load_photomaker_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str,
subfolder: str = '',
trigger_word: str = 'img',
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
weight_name (`str`):
The weight name NOT the path to the weight.
subfolder (`str`, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
trigger_word (`str`, *optional*, defaults to `"img"`):
The trigger word is used to identify the position of class word in the text prompt,
and it is recommended not to set it as a common word.
This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation.
"""
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"id_encoder": {}, "lora_weights": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("id_encoder."):
state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
elif key.startswith("lora_weights."):
state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["id_encoder", "lora_weights"]:
raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.")
self.trigger_word = trigger_word
# load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...")
id_encoder = PhotoMakerIDEncoder()
id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
self.id_encoder = id_encoder
self.id_image_processor = CLIPImageProcessor()
# load lora into models
print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]")
self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
# Add trigger word token
if self.tokenizer is not None:
self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)
self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
def encode_prompt_with_trigger_word(
self,
prompt: str,
prompt_2: Optional[str] = None,
num_id_images: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
class_tokens_mask: Optional[torch.LongTensor] = None,
):
device = device or self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Find the token id of the trigger word
image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)
# Define tokenizers and text encoders
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
text_encoders = (
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
input_ids = tokenizer.encode(prompt) # TODO: batch encode
clean_index = 0
clean_input_ids = []
class_token_index = []
# Find out the corresponding class word token based on the newly added trigger word token
for i, token_id in enumerate(input_ids):
if token_id == image_token_id:
class_token_index.append(clean_index - 1)
else:
clean_input_ids.append(token_id)
clean_index += 1
if len(class_token_index) != 1:
raise ValueError(
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
Trigger word: {self.trigger_word}, Prompt: {prompt}."
)
class_token_index = class_token_index[0]
# Expand the class word token and corresponding mask
class_token = clean_input_ids[class_token_index]
clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \
clean_input_ids[class_token_index + 1:]
# Truncation or padding
max_len = tokenizer.model_max_length
if len(clean_input_ids) > max_len:
clean_input_ids = clean_input_ids[:max_len]
else:
clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
max_len - len(clean_input_ids)
)
class_tokens_mask = [True if class_token_index <= i < class_token_index + num_id_images else False \
for i in range(len(clean_input_ids))]
clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)
prompt_embeds = text_encoder(
clean_input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case
return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
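# Worked example (hypothetical): trigger_word="img", num_id_images=2,
# prompt = "a photo of a man img wearing a suit".
# The "img" token is dropped from the input ids, the class token directly before it ("man")
# is repeated num_id_images times, and class_tokens_mask marks exactly those repeated positions
# (everything else, including padding out to model_max_length, is False).
# Conceptually the encoded prompt becomes "a photo of a man man wearing a suit".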
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
# Added parameters (for PhotoMaker)
input_id_images: PipelineImageInput = None,
start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future
class_tokens_mask: Optional[torch.LongTensor] = None,
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Only the parameters introduced by PhotoMaker are discussed here.
For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
Args:
input_id_images (`PipelineImageInput`, *optional*):
Input ID Image to work with PhotoMaker.
class_tokens_mask (`torch.LongTensor`, *optional*):
Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
#
if prompt_embeds is not None and class_tokens_mask is None:
raise ValueError(
"If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
)
# check the input id images
if input_id_images is None:
raise ValueError(
"Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
)
if not isinstance(input_id_images, list):
input_id_images = [input_id_images]
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
assert do_classifier_free_guidance
# 3. Encode input prompt
num_id_images = len(input_id_images)
(
prompt_embeds,
pooled_prompt_embeds,
class_tokens_mask,
) = self.encode_prompt_with_trigger_word(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_id_images=num_id_images,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
class_tokens_mask=class_tokens_mask,
)
# 4. Encode input prompt without the trigger word for delayed conditioning
prompt_text_only = prompt.replace(" " + self.trigger_word, "") # sensitive to white space
(
prompt_embeds_text_only,
negative_prompt_embeds,
pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt_text_only,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds_text_only,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds_text_only,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
# 5. Prepare the input ID images
dtype = next(self.id_encoder.parameters()).dtype
if not isinstance(input_id_images[0], torch.Tensor):
id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts
# 6. Get the update text embedding with the stacked ID embedding
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
# 7. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 8. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Prepare added time ids & embeddings
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 11. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if i <= start_merge_step:
current_prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
else:
current_prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=current_prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
image = latents
return StableDiffusionXLPipelineOutput(images=image)
# apply watermark if available
# if self.watermark is not None:
# image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
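# Illustrative end-to-end usage of the pipeline above. The checkpoint repo, weight filename and
# import path are assumptions (the upstream PhotoMaker release), not something this commit pins down.
import torch
from PIL import Image
from toolkit.pipelines import PhotoMakerStableDiffusionXLPipeline   # import path assumed

pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_photomaker_adapter(
    "TencentARC/PhotoMaker",             # assumed upstream checkpoint repo
    weight_name="photomaker-v1.bin",
    trigger_word="img",
)
id_images = [Image.open("person_1.jpg"), Image.open("person_2.jpg")]  # hypothetical ID photos
result = pipe(
    prompt="a photo of a man img wearing a suit",   # trigger word placed right after the class word
    input_id_images=id_images,
    num_inference_steps=50,
    start_merge_step=10,
    guidance_scale=5.0,
)
result.images[0].save("photomaker_sample.png")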

View File

@@ -246,6 +246,25 @@ def load_ip_adapter_model(
else:
return torch.load(path_to_file, map_location=device)
def load_custom_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)
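# Illustration of the key grouping above with dummy tensors: flat safetensors keys such as
# "fuse_module.mlp1.fc1.weight" are split on the first dot into the nested layout that
# CustomAdapter.load_state_dict expects.
import torch
from collections import OrderedDict

raw_state_dict = {
    "fuse_module.mlp1.fc1.weight": torch.zeros(2),
    "id_encoder.visual_projection_2.weight": torch.zeros(2),
}
combined = OrderedDict()
for combo_key, value in raw_state_dict.items():
    module_name, _, sub_key = combo_key.partition('.')
    combined.setdefault(module_name, OrderedDict())[sub_key] = value
# combined["fuse_module"] == OrderedDict([("mlp1.fc1.weight", tensor([0., 0.]))])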
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
lora_keymap = OrderedDict()

View File

@@ -19,6 +19,7 @@ from tqdm import tqdm
from torchvision.transforms import Resize, transforms
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
@@ -483,6 +484,13 @@ class StableDiffusion:
transforms.ToTensor(),
])
validation_image = transform(validation_image)
if isinstance(self.adapter, CustomAdapter):
# todo allow loading multiple
transform = transforms.Compose([
transforms.ToTensor(),
])
validation_image = transform(validation_image)
self.adapter.num_control_images = 1
if isinstance(self.adapter, ReferenceAdapter):
# need -1 to 1
validation_image = transforms.ToTensor()(validation_image)
@@ -501,6 +509,19 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
# condition the prompts for the custom adapter
gen_config.prompt = self.adapter.condition_prompt(
gen_config.prompt,
is_unconditional=False,
)
gen_config.prompt_2 = gen_config.prompt
gen_config.negative_prompt = self.adapter.condition_prompt(
gen_config.negative_prompt,
is_unconditional=True,
)
gen_config.negative_prompt_2 = gen_config.negative_prompt
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
@@ -524,6 +545,21 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=conditional_embeds,
is_training=False,
has_been_preprocessed=False,
)
unconditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=unconditional_embeds,
is_training=False,
has_been_preprocessed=False,
is_unconditional=True,
)
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
# if we have a refiner loaded, set the denoising end at the refiner start
extra['denoising_end'] = gen_config.refiner_start_at
@@ -1468,6 +1504,9 @@ class StableDiffusion:
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device
elif isinstance(self.adapter, CustomAdapter):
requires_grad = self.adapter.training
adapter_device = self.adapter.device
elif isinstance(self.adapter, ReferenceAdapter):
# todo update this!!
requires_grad = True