From eebd3c8212705bf55b17307a8f4c343c62ce8924 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 15 Jan 2024 18:46:26 -0700 Subject: [PATCH] Initial training script for photomaker training. Needs a little more work. --- extensions_built_in/sd_trainer/SDTrainer.py | 66 ++- jobs/process/BaseSDTrainProcess.py | 30 +- toolkit/config_modules.py | 4 +- toolkit/custom_adapter.py | 417 +++++++++++++++++ toolkit/photomaker.py | 141 ++++++ toolkit/photomaker_pipeline.py | 491 ++++++++++++++++++++ toolkit/saving.py | 19 + toolkit/stable_diffusion_model.py | 39 ++ 8 files changed, 1183 insertions(+), 24 deletions(-) create mode 100644 toolkit/custom_adapter.py create mode 100644 toolkit/photomaker.py create mode 100644 toolkit/photomaker_pipeline.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 64f1da48..56ec7134 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -14,6 +14,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileIte from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter +from toolkit.custom_adapter import CustomAdapter from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork @@ -145,10 +146,11 @@ class SDTrainer(BaseSDTrainProcess): self.sd.noise_scheduler._step_index = None denoised_latent = self.sd.noise_scheduler.step( - pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False + pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False )[0] - residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( + self.train_config.dtype)) # remove the residual noise from the denoised latents. 
Output should be a clean prediction (theoretically) denoised_latent = denoised_latent - residual_noise @@ -232,7 +234,7 @@ class SDTrainer(BaseSDTrainProcess): pred = noise_pred if self.train_config.train_turbo: - pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) + pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) ignore_snr = False @@ -298,7 +300,8 @@ class SDTrainer(BaseSDTrainProcess): loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: # add snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True) + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: # add min_snr_gamma loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) @@ -631,7 +634,9 @@ class SDTrainer(BaseSDTrainProcess): self.network.is_active = False can_disable_adapter = False was_adapter_active = False - if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)): + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or + isinstance(self.adapter, ReferenceAdapter) + ): can_disable_adapter = True was_adapter_active = self.adapter.is_active self.adapter.is_active = False @@ -698,6 +703,13 @@ class SDTrainer(BaseSDTrainProcess): batch = self.preprocess_batch(batch) dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + self.adapter.num_control_images = 1 + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + network_weight_list = batch.get_network_weight_list() if self.train_config.single_item_batching: network_weight_list = network_weight_list + network_weight_list @@ -706,7 +718,8 @@ class SDTrainer(BaseSDTrainProcess): has_clip_image = batch.clip_image_tensor is not None if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: - raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") match_adapter_assist = False @@ -752,7 +765,6 @@ class SDTrainer(BaseSDTrainProcess): if batch.clip_image_tensor is not None: clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() - mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) if batch.mask_tensor is not None: with self.timer('get_mask_multiplier'): @@ -879,12 +891,13 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier_list, prompt_2_list ): - if self.train_config.negative_prompt is not None: - # add negative prompt - conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in - range(len(conditioned_prompts))] - if prompt_2 is not None: - prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] + + # if self.train_config.negative_prompt is not None: + # # add negative prompt + # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + # range(len(conditioned_prompts))] + # if prompt_2 is not None: + # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] with network: # encode clip adapter here so embeds are active for tokenizer @@ -977,7 +990,6 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals - if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): image_size = self.adapter.input_size @@ -1029,11 +1041,11 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.do_cfg: unconditional_clip_embeds = unconditional_clip_embeds.detach() - with self.timer('encode_adapter'): conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds) if self.train_config.do_cfg: - unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds) + unconditional_embeds = self.adapter(unconditional_embeds.detach(), + unconditional_clip_embeds) if self.adapter and isinstance(self.adapter, ReferenceAdapter): # pass in our scheduler @@ -1060,7 +1072,8 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: do_inverted_masked_prior = True - if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior): + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior): with self.timer('prior predict'): prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, @@ -1074,6 +1087,25 @@ class SDTrainer(BaseSDTrainProcess): unconditional_embeds=unconditional_embeds ) + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True + ) + + self.before_unet_predict() # do a prior pred if we have an unconditional image, we will swap out the giadance later if 
batch.unconditional_latents is not None or self.do_guided_loss: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index fcccb68c..36b0cb1a 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -18,6 +18,7 @@ import torch.backends.cuda from toolkit.basic import value_map from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.embedding import Embedding @@ -34,7 +35,7 @@ from toolkit.progress_bar import ToolkitProgressBar from toolkit.reference_adapter import ReferenceAdapter from toolkit.sampler import get_sampler from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \ - load_ip_adapter_model + load_ip_adapter_model, load_custom_adapter_model from toolkit.scheduler import get_lr_scheduler from toolkit.sd_device_states_presets import get_train_sd_device_state_preset @@ -141,7 +142,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None - self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None self.embedding: Union[Embedding, None] = None is_training_adapter = self.adapter_config is not None and self.adapter_config.train @@ -412,8 +413,10 @@ class BaseSDTrainProcess(BaseTrainProcess): adapter_name += '_t2i' elif self.adapter_config.type == 'clip': adapter_name += '_clip' - else: + elif self.adapter_config.type == 'ip': adapter_name += '_ip' + else: + adapter_name += '_adapter' filename = f'{adapter_name}{step_num}.safetensors' file_path = os.path.join(self.save_root, filename) @@ -931,8 +934,10 @@ class BaseSDTrainProcess(BaseTrainProcess): suffix = 'clip' elif self.adapter_config.type == 'reference': suffix = 'ref' - else: + elif self.adapter_config.type.startswith('ip'): suffix = 'ip' + else: + suffix = 'adapter' adapter_name = self.name if self.network_config is not None: adapter_name = f"{adapter_name}_{suffix}" @@ -967,13 +972,18 @@ class BaseSDTrainProcess(BaseTrainProcess): sd=self.sd, adapter_config=self.adapter_config, ) - else: + elif self.adapter_config.type.startswith('ip'): self.adapter = IPAdapter( sd=self.sd, adapter_config=self.adapter_config, ) if self.train_config.gradient_checkpointing: self.adapter.enable_gradient_checkpointing() + else: + self.adapter = CustomAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) self.adapter.to(self.device_torch, dtype=dtype) if latest_save_path is not None: # load adapter from path @@ -985,7 +995,7 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype=dtype ) self.adapter.load_state_dict(loaded_state_dict) - else: + elif self.adapter_config.type.startswith('ip'): # ip adapter loaded_state_dict = load_ip_adapter_model( latest_save_path, @@ -993,6 +1003,14 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype=dtype ) self.adapter.load_state_dict(loaded_state_dict) + else: + # custom adapter + loaded_state_dict = load_custom_adapter_model( + latest_save_path, + self.device, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) if self.adapter_config.train: self.load_training_state_from_metadata(latest_save_path) # set trainable params diff --git 
a/toolkit/config_modules.py b/toolkit/config_modules.py index 9293ebb4..2803a8e6 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -127,7 +127,7 @@ class NetworkConfig: self.conv = 4 -AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora'] +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker'] class AdapterConfig: @@ -162,6 +162,8 @@ class AdapterConfig: self.trigger = kwargs.get('trigger', 'tri993r') self.trigger_class_name = kwargs.get('trigger_class_name', 'person') + self.class_names = kwargs.get('class_names', []) + class EmbeddingConfig: def __init__(self, **kwargs): diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py new file mode 100644 index 00000000..26fb7614 --- /dev/null +++ b/toolkit/custom_adapter.py @@ -0,0 +1,417 @@ +import torch +import sys + +from PIL import Image +from torch.nn import Parameter +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from toolkit.models.clip_pre_processor import CLIPImagePreProcessor +from toolkit.paths import REPOS_ROOT +from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder +from toolkit.saving import load_ip_adapter_model +from toolkit.train_tools import get_torch_dtype + +sys.path.append(REPOS_ROOT) +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional +from collections import OrderedDict +from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ + AttnProcessor2_0 +from ipadapter.ip_adapter.ip_adapter import ImageProjModel +from ipadapter.ip_adapter.resampler import Resampler +from toolkit.config_modules import AdapterConfig, AdapterTypes +from toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + CLIPVisionModel, + AutoImageProcessor, + ConvNextModel, + ConvNextForImageClassification, + ConvNextImageProcessor +) +from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel + +from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification + +from transformers import ViTFeatureExtractor, ViTForImageClassification + +import torch.nn.functional as F + + +class CustomAdapter(torch.nn.Module): + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.image_processor: CLIPImageProcessor = None + self.input_size = 224 + self.adapter_type: AdapterTypes = self.config.type + self.current_scale = 1.0 + self.is_active = True + self.flag_word = "fla9wor0" + + self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None + + self.fuse_module: FuseModule = None + + self.lora: None = None + + self.position_ids: Optional[List[int]] = None + + self.num_control_images = 1 + self.token_mask: Optional[torch.Tensor] = None + + # setup clip + self.setup_clip() + # add for dataloader + self.clip_image_processor = self.image_processor + + self.setup_adapter() + + # try to load from our name_or_path + if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'): + self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False) + + # add the trigger word to the tokenizer + if 
isinstance(self.sd_ref().tokenizer, list): + for tokenizer in self.sd_ref().tokenizer: + tokenizer.add_tokens([self.flag_word], special_tokens=True) + else: + self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True) + + def setup_adapter(self): + if self.adapter_type == 'photo_maker': + sd = self.sd_ref() + embed_dim = sd.unet.config['cross_attention_dim'] + self.fuse_module = FuseModule(embed_dim) + else: + raise ValueError(f"unknown adapter type: {self.adapter_type}") + + def forward(self, *args, **kwargs): + if self.adapter_type == 'photo_maker': + id_pixel_values = args[0] + prompt_embeds: PromptEmbeds = args[1] + class_tokens_mask = args[2] + + grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled() + + with torch.set_grad_enabled(grads_on_image_encoder): + id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False) + + if not grads_on_image_encoder: + id_embeds = id_embeds.detach() + + prompt_embeds = prompt_embeds.detach() + + updated_prompt_embeds = self.fuse_module( + prompt_embeds, id_embeds, class_tokens_mask + ) + + return updated_prompt_embeds + else: + raise NotImplementedError + + def setup_clip(self): + adapter_config = self.config + sd = self.sd_ref() + if self.config.type == 'photo_maker': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + if self.config.image_encoder_path is None: + self.vision_encoder = PhotoMakerCLIPEncoder() + else: + self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path) + elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+': + try: + self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = CLIPImageProcessor() + self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'siglip': + from transformers import SiglipImageProcessor, SiglipVisionModel + try: + self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SiglipImageProcessor() + self.vision_encoder = SiglipVisionModel.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit': + try: + self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = ViTFeatureExtractor() + self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'safe': + try: + self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.image_processor = SAFEImageProcessor() + self.vision_encoder = SAFEVisionModel( + in_channels=3, + num_tokens=self.config.safe_tokens, + num_vectors=sd.unet.config['cross_attention_dim'], + reducer_channels=self.config.safe_reducer_channels, + channels=self.config.safe_channels, + downscale_factor=8 + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + 
elif self.config.image_encoder_arch == 'convnext': + try: + self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.vision_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit-hybrid': + try: + self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.image_processor = ViTHybridImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.vision_encoder = ViTHybridForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + + self.input_size = self.vision_encoder.config.image_size + + if self.config.image_encoder_arch == 'clip+': + # self.image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + preprocessor_input_size = self.vision_encoder.config.image_size * 4 + + # update the preprocessor so images come in at the right size + self.image_processor.size['shortest_edge'] = preprocessor_input_size + self.image_processor.crop_size['height'] = preprocessor_input_size + self.image_processor.crop_size['width'] = preprocessor_input_size + + self.preprocessor = CLIPImagePreProcessor( + input_size=preprocessor_input_size, + clip_input_size=self.vision_encoder.config.image_size, + ) + if 'height' in self.image_processor.size: + self.input_size = self.image_processor.size['height'] + else: + self.input_size = self.image_processor.crop_size['height'] + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + + if 'lora_weights' in state_dict: + # todo add LoRA + # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") + # self.sd_ref().pipeline.fuse_lora() + pass + if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker': + self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict) + # check to see if the fuse weights are there + fuse_weights = {} + for k, v in state_dict['id_encoder'].items(): + if k.startswith('fuse_module'): + k = k.replace('fuse_module.', '') + fuse_weights[k] = v + if len(fuse_weights) > 0: + self.fuse_module.load_state_dict(fuse_weights, strict=strict) + + if 'fuse_module' in state_dict: + self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict) + + pass + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + if self.adapter_type == 'photo_maker': + if self.train_image_encoder: + state_dict["id_encoder"] = self.vision_encoder.state_dict() + + state_dict["fuse_module"] = self.fuse_module.state_dict() + + # todo save LoRA + else: + raise NotImplementedError + + def condition_prompt( + self, + prompt: Union[List[str], str], + is_unconditional: bool = False, + ): + if 
self.adapter_type == 'photo_maker': + if is_unconditional: + return prompt + else: + + with torch.no_grad(): + was_list = isinstance(prompt, list) + if not was_list: + prompt_list = [prompt] + else: + prompt_list = prompt + + new_prompt_list = [] + token_mask_list = [] + + for prompt in prompt_list: + + our_class = None + # find a class in the prompt + prompt_parts = prompt.split(' ') + prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0] + + new_prompt_parts = [] + tokened_prompt_parts = [] + for idx, prompt_part in enumerate(prompt_parts): + new_prompt_parts.append(prompt_part) + tokened_prompt_parts.append(prompt_part) + if prompt_part in self.config.class_names: + our_class = prompt_part + # add the flag word + tokened_prompt_parts.append(self.flag_word) + + if self.num_control_images > 1: + # add the rest + for _ in range(self.num_control_images - 1): + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + # add the rest + tokened_prompt_parts.extend(prompt_parts[idx + 1:]) + new_prompt_parts.extend(prompt_parts[idx + 1:]) + + break + + prompt = " ".join(new_prompt_parts) + tokened_prompt = " ".join(tokened_prompt_parts) + + if our_class is None: + # add the first one to the front of the prompt + tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt + our_class = self.config.class_names[0] + prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt + + # add the prompt to the list + new_prompt_list.append(prompt) + + # tokenize them with just the first tokenizer + tokenizer = self.sd_ref().tokenizer + if isinstance(tokenizer, list): + tokenizer = tokenizer[0] + + flag_token = tokenizer.convert_tokens_to_ids(self.flag_word) + + tokenized_prompt = tokenizer.encode(prompt) + tokenized_tokened_prompt = tokenizer.encode(tokened_prompt) + + flag_idx = tokenized_tokened_prompt.index(flag_token) + + class_token = tokenized_prompt[flag_idx - 1] + + + boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool) + boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool))) + boolean_mask = boolean_mask.to(self.device) + # zero pad it to 77 + boolean_mask = F.pad(boolean_mask, (0, 77 - boolean_mask.shape[0]), value=False) + + token_mask_list.append(boolean_mask) + + self.token_mask = torch.cat(token_mask_list, dim=0).to(self.device) + + prompt_list = new_prompt_list + + if not was_list: + prompt = prompt_list[0] + else: + prompt = prompt_list + + return prompt + + def condition_encoded_embeds( + self, + tensors_0_1: torch.Tensor, + prompt_embeds: PromptEmbeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=False + ) -> PromptEmbeds: + if self.adapter_type == 'photo_maker': + if is_unconditional: + # we dont condition the negative embeds for photo maker + return prompt_embeds + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + clip_image = self.image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + + # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image + clip_image = clip_image.unsqueeze(1) + + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + do_projection2=isinstance(self.sd_ref().text_encoder, list), + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list) + ).detach() + + prompt_embeds.text_embeds = self.fuse_module( + prompt_embeds.text_embeds, + id_embeds, + self.token_mask + ) + + return prompt_embeds + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.type == 'photo_maker': + yield from self.fuse_module.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) diff --git a/toolkit/photomaker.py b/toolkit/photomaker.py new file mode 100644 index 00000000..9f8d69ef --- /dev/null +++ b/toolkit/photomaker.py @@ -0,0 +1,141 @@ +# Merge image encoder and fuse module to create an ID Encoder +# send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding + +import torch +import torch.nn as nn +from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection +from transformers.models.clip.configuration_clip import CLIPVisionConfig +from transformers import PretrainedConfig + +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768 +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) + self.layer_norm = nn.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # 
seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[1] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module( + prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection): + def __init__(self, config=None, *model_args, **model_kwargs): + if config is None: + config = CLIPVisionConfig(**VISION_CONFIG_DICT) + super().__init__(config, *model_args, **model_kwargs) + self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) + + def forward(self, id_pixel_values, do_projection2=True): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[1] + id_embeds = self.visual_projection(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + + if do_projection2: + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + + return id_embeds + + + +if __name__ == "__main__": + PhotoMakerIDEncoder() \ No newline at end of file diff --git a/toolkit/photomaker_pipeline.py b/toolkit/photomaker_pipeline.py new file mode 100644 index 00000000..d6437b64 --- /dev/null +++ b/toolkit/photomaker_pipeline.py @@ -0,0 +1,491 @@ +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from collections import OrderedDict +import os +import PIL +import numpy as np + +import torch +from torchvision import transforms as T + +from safetensors import safe_open +from huggingface_hub.utils import 
validate_hf_hub_args +from transformers import CLIPImageProcessor, CLIPTokenizer +from diffusers import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.utils import ( + _get_model_file, + is_transformers_available, + logging, +) + +from .photomaker import PhotoMakerIDEncoder + +PipelineImageInput = Union[ + PIL.Image.Image, + torch.FloatTensor, + List[PIL.Image.Image], + List[torch.FloatTensor], +] + + +class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): + @validate_hf_hub_args + def load_photomaker_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str, + subfolder: str = '', + trigger_word: str = 'img', + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + weight_name (`str`): + The weight name NOT the path to the weight. + + subfolder (`str`, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + trigger_word (`str`, *optional*, defaults to `"img"`): + The trigger word is used to identify the position of class word in the text prompt, + and it is recommended not to set it as a common word. + This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation. + """ + + # Load the main state dict first. 
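        # (The checkpoint is expected to hold two groups of tensors, prefixed
        # id_encoder.* and lora_weights.*; they are split apart below and loaded
        # into the ID encoder and the UNet LoRA respectively.)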
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"id_encoder": {}, "lora_weights": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("id_encoder."): + state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key) + elif key.startswith("lora_weights."): + state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["id_encoder", "lora_weights"]: + raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.") + + self.trigger_word = trigger_word + # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet + print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...") + id_encoder = PhotoMakerIDEncoder() + id_encoder.load_state_dict(state_dict["id_encoder"], strict=True) + id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype) + self.id_encoder = id_encoder + self.id_image_processor = CLIPImageProcessor() + + # load lora into models + print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]") + self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") + + # Add trigger word token + if self.tokenizer is not None: + self.tokenizer.add_tokens([self.trigger_word], special_tokens=True) + + self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True) + + def encode_prompt_with_trigger_word( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_id_images: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + class_tokens_mask: Optional[torch.LongTensor] = None, + ): + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Find the token id of the trigger word + image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word) + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_embeds_list = [] + prompts 
= [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + input_ids = tokenizer.encode(prompt) # TODO: batch encode + clean_index = 0 + clean_input_ids = [] + class_token_index = [] + # Find out the corrresponding class word token based on the newly added trigger word token + for i, token_id in enumerate(input_ids): + if token_id == image_token_id: + class_token_index.append(clean_index - 1) + else: + clean_input_ids.append(token_id) + clean_index += 1 + + if len(class_token_index) != 1: + raise ValueError( + f"PhotoMaker currently does not support multiple trigger words in a single prompt.\ + Trigger word: {self.trigger_word}, Prompt: {prompt}." + ) + class_token_index = class_token_index[0] + + # Expand the class word token and corresponding mask + class_token = clean_input_ids[class_token_index] + clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \ + clean_input_ids[class_token_index + 1:] + + # Truncation or padding + max_len = tokenizer.model_max_length + if len(clean_input_ids) > max_len: + clean_input_ids = clean_input_ids[:max_len] + else: + clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * ( + max_len - len(clean_input_ids) + ) + + class_tokens_mask = [True if class_token_index <= i < class_token_index + num_id_images else False \ + for i in range(len(clean_input_ids))] + + clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0) + class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0) + + prompt_embeds = text_encoder( + clean_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case + + return prompt_embeds, pooled_prompt_embeds, class_tokens_mask + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + # Added parameters (for PhotoMaker) + input_id_images: PipelineImageInput = None, + 
start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future + class_tokens_mask: Optional[torch.LongTensor] = None, + prompt_embeds_text_only: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + Only the parameters introduced by PhotoMaker are discussed here. + For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py + + Args: + input_id_images (`PipelineImageInput`, *optional*): + Input ID Image to work with PhotoMaker. + class_tokens_mask (`torch.LongTensor`, *optional*): + Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word. + prompt_embeds_text_only (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + # + if prompt_embeds is not None and class_tokens_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." + ) + # check the input id images + if input_id_images is None: + raise ValueError( + "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline." + ) + if not isinstance(input_id_images, list): + input_id_images = [input_id_images] + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + assert do_classifier_free_guidance + + # 3. 
Encode input prompt + num_id_images = len(input_id_images) + + ( + prompt_embeds, + pooled_prompt_embeds, + class_tokens_mask, + ) = self.encode_prompt_with_trigger_word( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_id_images=num_id_images, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + class_tokens_mask=class_tokens_mask, + ) + + # 4. Encode input prompt without the trigger word for delayed conditioning + prompt_text_only = prompt.replace(" " + self.trigger_word, "") # sensitive to white space + ( + prompt_embeds_text_only, + negative_prompt_embeds, + pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt_text_only, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds_text_only, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds_text_only, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 5. Prepare the input ID images + dtype = next(self.id_encoder.parameters()).dtype + if not isinstance(input_id_images[0], torch.Tensor): + id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values + + id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts + + # 6. Get the update text embedding with the stacked ID embedding + prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # 7. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 8. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if i <= start_merge_step: + current_prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds_text_only], dim=0 + ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=current_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + # if self.watermark is not None: + # image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file diff --git a/toolkit/saving.py b/toolkit/saving.py index e3c63b22..597f4620 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -246,6 +246,25 @@ def load_ip_adapter_model( else: return torch.load(path_to_file, map_location=device) +def load_custom_adapter_model( + path_to_file, + device: Union[str] = 'cpu', + dtype: torch.dtype = torch.float32 +): + # check if it is safetensors or checkpoint + if path_to_file.endswith('.safetensors'): + raw_state_dict = load_file(path_to_file, device) + combined_state_dict = OrderedDict() + for combo_key, value in raw_state_dict.items(): + key_split = combo_key.split('.') + module_name = key_split.pop(0) + if module_name not in 
combined_state_dict: + combined_state_dict[module_name] = OrderedDict() + combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype) + return combined_state_dict + else: + return torch.load(path_to_file, map_location=device) + def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict': lora_keymap = OrderedDict() diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index cb30273a..d0fb6b87 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -19,6 +19,7 @@ from tqdm import tqdm from torchvision.transforms import Resize, transforms from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae @@ -483,6 +484,13 @@ class StableDiffusion: transforms.ToTensor(), ]) validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 if isinstance(self.adapter, ReferenceAdapter): # need -1 to 1 validation_image = transforms.ToTensor()(validation_image) @@ -501,6 +509,19 @@ class StableDiffusion: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) self.adapter(conditional_clip_embeds) + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + # encode the prompt ourselves so we can do fun stuff with embeddings conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) @@ -524,6 +545,21 @@ class StableDiffusion: conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds) + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + ) + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: # if we have a refiner loaded, set the denoising end at the refiner start extra['denoising_end'] = gen_config.refiner_start_at @@ -1468,6 +1504,9 @@ class StableDiffusion: elif isinstance(self.adapter, ClipVisionAdapter): requires_grad = self.adapter.embedder.training adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device elif isinstance(self.adapter, ReferenceAdapter): # todo update this!! requires_grad = True
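
The new photo_maker entry added to AdapterTypes is driven by AdapterConfig. A minimal sketch of building such a config directly in Python, restricted to keys that CustomAdapter reads in this patch (type, class_names, image_encoder_path, train_image_encoder, name_or_path, train); the values are placeholders and the constructor behaviour is assumed from the surrounding kwargs-based config classes:

from toolkit.config_modules import AdapterConfig

adapter_config = AdapterConfig(
    type="photo_maker",
    class_names=["man", "woman", "person"],  # prompt words that receive the stacked ID embeds
    image_encoder_path=None,                 # None -> a freshly initialized PhotoMakerCLIPEncoder
    train_image_encoder=False,
    name_or_path=None,                       # optionally a .bin checkpoint to warm-start from
    train=True,
)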
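
At the heart of the new toolkit/photomaker.py is FuseModule, which overwrites the duplicated class-token embeddings in the prompt with embeddings fused from the stacked ID images. A minimal sketch of calling it in isolation, with random placeholder tensors and shapes taken from the comments in the patch (assumes this repository is on the import path):

import torch
from toolkit.photomaker import FuseModule

embed_dim = 2048                     # 768 + 1280, the concatenated SDXL text-embed width
fuse = FuseModule(embed_dim)

batch, seq_len, num_inputs = 1, 77, 2
prompt_embeds = torch.randn(batch, seq_len, embed_dim)
id_embeds = torch.randn(batch, num_inputs, 1, embed_dim)   # [b, max_num_inputs, 1, 2048]

# mark the num_inputs duplicated class-token positions inside the 77-token prompt
class_tokens_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
class_tokens_mask[:, 5:5 + num_inputs] = True

updated_prompt_embeds = fuse(prompt_embeds, id_embeds, class_tokens_mask)
print(updated_prompt_embeds.shape)   # torch.Size([1, 77, 2048])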
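
Both CustomAdapter.condition_prompt and encode_prompt_with_trigger_word in the new pipeline do the same bookkeeping: remove the trigger token, repeat the class token once per input ID image, and record those positions in a boolean class_tokens_mask. A standalone sketch of that step (the tokenizer checkpoint and the example prompt are assumptions):

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
trigger_word = "img"
tokenizer.add_tokens([trigger_word], special_tokens=True)
image_token_id = tokenizer.convert_tokens_to_ids(trigger_word)

prompt = "a photo of a man img wearing a suit"
num_id_images = 2

input_ids = tokenizer.encode(prompt)
clean_input_ids, class_token_index = [], []
for token_id in input_ids:
    if token_id == image_token_id:
        class_token_index.append(len(clean_input_ids) - 1)  # the class word sits right before the trigger
    else:
        clean_input_ids.append(token_id)

idx = class_token_index[0]
class_token = clean_input_ids[idx]
# repeat the class token once per input ID image, then pad/truncate to the model max length
clean_input_ids = clean_input_ids[:idx] + [class_token] * num_id_images + clean_input_ids[idx + 1:]
max_len = tokenizer.model_max_length
clean_input_ids = (clean_input_ids + [tokenizer.pad_token_id] * max_len)[:max_len]
class_tokens_mask = torch.tensor(
    [idx <= i < idx + num_id_images for i in range(max_len)], dtype=torch.bool
)
# class_tokens_mask now marks the duplicated class-word positions that FuseModule will overwrite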
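
load_custom_adapter_model in toolkit/saving.py regroups a flat safetensors state dict into one nested dict per sub-module by splitting each key at its first dot. A tiny sketch of that regrouping with illustrative keys:

from collections import OrderedDict
import torch

flat = {  # illustrative keys only
    "fuse_module.mlp1.fc1.weight": torch.zeros(2048, 4096),
    "id_encoder.visual_projection_2.weight": torch.zeros(1280, 1024),
}
combined = OrderedDict()
for combo_key, value in flat.items():
    module_name, sub_key = combo_key.split(".", 1)
    combined.setdefault(module_name, OrderedDict())[sub_key] = value
# combined == {"fuse_module": {"mlp1.fc1.weight": ...}, "id_encoder": {"visual_projection_2.weight": ...}},
# which is the nested layout CustomAdapter.load_state_dict expects.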
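
Finally, a hedged usage sketch for the new PhotoMakerStableDiffusionXLPipeline; the hub repository id, weight file name, and input image path follow the public PhotoMaker release and are assumptions rather than anything defined in this patch:

import torch
from PIL import Image
from toolkit.photomaker_pipeline import PhotoMakerStableDiffusionXLPipeline

pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_photomaker_adapter(
    "TencentARC/PhotoMaker",           # assumed hub repo of the released weights
    weight_name="photomaker-v1.bin",   # assumed file name
    trigger_word="img",
)
image = pipe(
    prompt="a photo of a man img in a spacesuit",
    input_id_images=[Image.open("id_photo.jpg")],  # placeholder path to an ID image
    num_inference_steps=30,
    start_merge_step=10,               # steps before this use the text-only embeddings
).images[0]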