Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Initial training script for photomaker training. Needs a little more work.
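For orientation, the sketch below shows how the new photo_maker adapter type might be wired up, using only names introduced in this diff (AdapterConfig, CustomAdapter, the 'photo_maker' type and the new class_names key). The concrete values and the pre-loaded sd model are illustrative assumptions, not something this commit ships.

# Hypothetical usage sketch -- values below are examples, not part of this commit.
from toolkit.config_modules import AdapterConfig
from toolkit.custom_adapter import CustomAdapter

def build_photomaker_adapter(sd):
    # `sd` is assumed to be an already-loaded
    # toolkit.stable_diffusion_model.StableDiffusion instance (SDXL).
    adapter_config = AdapterConfig(
        type='photo_maker',                       # new AdapterTypes value added below
        class_names=['person', 'man', 'woman'],   # new key added to AdapterConfig
        trigger='tri993r',                        # defaults shown in config_modules.py
        trigger_class_name='person',
        image_encoder_path='openai/clip-vit-large-patch14',  # example CLIP vision checkpoint
    )
    # CustomAdapter registers its internal flag token with the tokenizer(s)
    # and builds a FuseModule sized to the UNet cross-attention dim.
    return CustomAdapter(sd=sd, adapter_config=adapter_config)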
@@ -14,6 +14,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileIte
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
@@ -148,7 +149,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
|
||||
)[0]
|
||||
|
||||
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
||||
residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
|
||||
self.train_config.dtype))
|
||||
# remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
|
||||
denoised_latent = denoised_latent - residual_noise
|
||||
|
||||
@@ -298,7 +300,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
|
||||
fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
@@ -631,7 +634,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.network.is_active = False
|
||||
can_disable_adapter = False
|
||||
was_adapter_active = False
|
||||
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
|
||||
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
|
||||
isinstance(self.adapter, ReferenceAdapter)
|
||||
):
|
||||
can_disable_adapter = True
|
||||
was_adapter_active = self.adapter.is_active
|
||||
self.adapter.is_active = False
|
||||
@@ -698,6 +703,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
# condition the prompt
|
||||
# todo handle more than one adapter image
|
||||
self.adapter.num_control_images = 1
|
||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
@@ -706,7 +718,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
has_clip_image = batch.clip_image_tensor is not None
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||
raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
raise ValueError(
|
||||
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
@@ -752,7 +765,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if batch.clip_image_tensor is not None:
|
||||
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
|
||||
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
if batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
@@ -879,12 +891,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
mask_multiplier_list,
|
||||
prompt_2_list
|
||||
):
|
||||
if self.train_config.negative_prompt is not None:
|
||||
# add negative prompt
|
||||
conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
||||
range(len(conditioned_prompts))]
|
||||
if prompt_2 is not None:
|
||||
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
# if self.train_config.negative_prompt is not None:
|
||||
# # add negative prompt
|
||||
# conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
||||
# range(len(conditioned_prompts))]
|
||||
# if prompt_2 is not None:
|
||||
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
with network:
|
||||
# encode clip adapter here so embeds are active for tokenizer
|
||||
@@ -977,7 +990,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter_embeds'):
|
||||
image_size = self.adapter.input_size
|
||||
@@ -1029,11 +1041,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = unconditional_clip_embeds.detach()
|
||||
|
||||
|
||||
with self.timer('encode_adapter'):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
|
||||
unconditional_clip_embeds)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass in our scheduler
|
||||
@@ -1060,7 +1072,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||
do_inverted_masked_prior = True
|
||||
|
||||
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
|
||||
if ((
|
||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior):
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
@@ -1074,6 +1087,25 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeds=unconditional_embeds
|
||||
)
|
||||
|
||||
# do the custom adapter after the prior prediction
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||
self.adapter.train()
|
||||
conditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=clip_images,
|
||||
prompt_embeds=conditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
)
|
||||
if self.train_config.do_cfg and unconditional_embeds is not None:
|
||||
unconditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=clip_images,
|
||||
prompt_embeds=unconditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
is_unconditional=True
|
||||
)
|
||||
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
||||
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch.backends.cuda
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.embedding import Embedding
|
||||
@@ -34,7 +35,7 @@ from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
|
||||
load_ip_adapter_model
|
||||
load_ip_adapter_model, load_custom_adapter_model
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
@@ -141,7 +142,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
|
||||
@@ -412,8 +413,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
adapter_name += '_t2i'
|
||||
elif self.adapter_config.type == 'clip':
|
||||
adapter_name += '_clip'
|
||||
else:
|
||||
elif self.adapter_config.type == 'ip':
|
||||
adapter_name += '_ip'
|
||||
else:
|
||||
adapter_name += '_adapter'
|
||||
|
||||
filename = f'{adapter_name}{step_num}.safetensors'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
@@ -931,8 +934,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
suffix = 'clip'
|
||||
elif self.adapter_config.type == 'reference':
|
||||
suffix = 'ref'
|
||||
else:
|
||||
elif self.adapter_config.type.startswith('ip'):
|
||||
suffix = 'ip'
|
||||
else:
|
||||
suffix = 'adapter'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
@@ -967,13 +972,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
else:
|
||||
elif self.adapter_config.type.startswith('ip'):
|
||||
self.adapter = IPAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.adapter.enable_gradient_checkpointing()
|
||||
else:
|
||||
self.adapter = CustomAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
self.adapter.to(self.device_torch, dtype=dtype)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
@@ -985,7 +995,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
else:
|
||||
elif self.adapter_config.type.startswith('ip'):
|
||||
# ip adapter
|
||||
loaded_state_dict = load_ip_adapter_model(
|
||||
latest_save_path,
|
||||
@@ -993,6 +1003,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
else:
|
||||
# custom adapter
|
||||
loaded_state_dict = load_custom_adapter_model(
|
||||
latest_save_path,
|
||||
self.device,
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
if self.adapter_config.train:
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
# set trainable params
|
||||
|
||||
@@ -127,7 +127,7 @@ class NetworkConfig:
|
||||
self.conv = 4
|
||||
|
||||
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora']
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker']
|
||||
|
||||
|
||||
class AdapterConfig:
|
||||
@@ -162,6 +162,8 @@ class AdapterConfig:
|
||||
self.trigger = kwargs.get('trigger', 'tri993r')
|
||||
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
|
||||
|
||||
self.class_names = kwargs.get('class_names', [])
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
toolkit/custom_adapter.py (new file, 417 lines)
@@ -0,0 +1,417 @@
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from PIL import Image
|
||||
from torch.nn import Parameter
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from toolkit.config_modules import AdapterConfig, AdapterTypes
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
import weakref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPVisionModel,
|
||||
AutoImageProcessor,
|
||||
ConvNextModel,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor
|
||||
)
|
||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||
|
||||
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CustomAdapter(torch.nn.Module):
|
||||
def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
|
||||
super().__init__()
|
||||
self.config = adapter_config
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
self.device = self.sd_ref().unet.device
|
||||
self.image_processor: CLIPImageProcessor = None
|
||||
self.input_size = 224
|
||||
self.adapter_type: AdapterTypes = self.config.type
|
||||
self.current_scale = 1.0
|
||||
self.is_active = True
|
||||
self.flag_word = "fla9wor0"
|
||||
|
||||
self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
|
||||
|
||||
self.fuse_module: FuseModule = None
|
||||
|
||||
self.lora: None = None
|
||||
|
||||
self.position_ids: Optional[List[int]] = None
|
||||
|
||||
self.num_control_images = 1
|
||||
self.token_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# setup clip
|
||||
self.setup_clip()
|
||||
# add for dataloader
|
||||
self.clip_image_processor = self.image_processor
|
||||
|
||||
self.setup_adapter()
|
||||
|
||||
# try to load from our name_or_path
|
||||
if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
|
||||
self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
|
||||
|
||||
# add the trigger word to the tokenizer
|
||||
if isinstance(self.sd_ref().tokenizer, list):
|
||||
for tokenizer in self.sd_ref().tokenizer:
|
||||
tokenizer.add_tokens([self.flag_word], special_tokens=True)
|
||||
else:
|
||||
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
|
||||
|
||||
def setup_adapter(self):
|
||||
if self.adapter_type == 'photo_maker':
|
||||
sd = self.sd_ref()
|
||||
embed_dim = sd.unet.config['cross_attention_dim']
|
||||
self.fuse_module = FuseModule(embed_dim)
|
||||
else:
|
||||
raise ValueError(f"unknown adapter type: {self.adapter_type}")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.adapter_type == 'photo_maker':
|
||||
id_pixel_values = args[0]
|
||||
prompt_embeds: PromptEmbeds = args[1]
|
||||
class_tokens_mask = args[2]
|
||||
|
||||
grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
|
||||
|
||||
with torch.set_grad_enabled(grads_on_image_encoder):
|
||||
id_embeds = self.vision_encoder(id_pixel_values, do_projection2=False)
|
||||
|
||||
if not grads_on_image_encoder:
|
||||
id_embeds = id_embeds.detach()
|
||||
|
||||
prompt_embeds = prompt_embeds.detach()
|
||||
|
||||
updated_prompt_embeds = self.fuse_module(
|
||||
prompt_embeds, id_embeds, class_tokens_mask
|
||||
)
|
||||
|
||||
return updated_prompt_embeds
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_clip(self):
|
||||
adapter_config = self.config
|
||||
sd = self.sd_ref()
|
||||
if self.config.type == 'photo_maker':
|
||||
try:
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = CLIPImageProcessor()
|
||||
if self.config.image_encoder_path is None:
|
||||
self.vision_encoder = PhotoMakerCLIPEncoder()
|
||||
else:
|
||||
self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path)
|
||||
elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
|
||||
try:
|
||||
self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = CLIPImageProcessor()
|
||||
self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'siglip':
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
try:
|
||||
self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = SiglipImageProcessor()
|
||||
self.vision_encoder = SiglipVisionModel.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit':
|
||||
try:
|
||||
self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = ViTFeatureExtractor()
|
||||
self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'safe':
|
||||
try:
|
||||
self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.image_processor = SAFEImageProcessor()
|
||||
self.vision_encoder = SAFEVisionModel(
|
||||
in_channels=3,
|
||||
num_tokens=self.config.safe_tokens,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'],
|
||||
reducer_channels=self.config.safe_reducer_channels,
|
||||
channels=self.config.safe_channels,
|
||||
downscale_factor=8
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'convnext':
|
||||
try:
|
||||
self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.image_processor = ConvNextImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.vision_encoder = ConvNextForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
elif self.config.image_encoder_arch == 'vit-hybrid':
|
||||
try:
|
||||
self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
print(f"could not load image processor from {adapter_config.image_encoder_path}")
|
||||
self.image_processor = ViTHybridImageProcessor(
|
||||
size=320,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
)
|
||||
self.vision_encoder = ViTHybridForImageClassification.from_pretrained(
|
||||
adapter_config.image_encoder_path,
|
||||
use_safetensors=True,
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
else:
|
||||
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
|
||||
|
||||
self.input_size = self.vision_encoder.config.image_size
|
||||
|
||||
if self.config.image_encoder_arch == 'clip+':
|
||||
# self.image_processor.config
|
||||
# The preprocessor downscales the incoming image, so its input size is 4x the CLIP input size
|
||||
preprocessor_input_size = self.vision_encoder.config.image_size * 4
|
||||
|
||||
# update the preprocessor so images come in at the right size
|
||||
self.image_processor.size['shortest_edge'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['height'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['width'] = preprocessor_input_size
|
||||
|
||||
self.preprocessor = CLIPImagePreProcessor(
|
||||
input_size=preprocessor_input_size,
|
||||
clip_input_size=self.vision_encoder.config.image_size,
|
||||
)
|
||||
if 'height' in self.image_processor.size:
|
||||
self.input_size = self.image_processor.size['height']
|
||||
else:
|
||||
self.input_size = self.image_processor.crop_size['height']
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
strict = False
|
||||
|
||||
if 'lora_weights' in state_dict:
|
||||
# todo add LoRA
|
||||
# self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
|
||||
# self.sd_ref().pipeline.fuse_lora()
|
||||
pass
|
||||
if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
|
||||
self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
|
||||
# check to see if the fuse weights are there
|
||||
fuse_weights = {}
|
||||
for k, v in state_dict['id_encoder'].items():
|
||||
if k.startswith('fuse_module'):
|
||||
k = k.replace('fuse_module.', '')
|
||||
fuse_weights[k] = v
|
||||
if len(fuse_weights) > 0:
|
||||
self.fuse_module.load_state_dict(fuse_weights, strict=strict)
|
||||
|
||||
if 'fuse_module' in state_dict:
|
||||
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
|
||||
|
||||
pass
|
||||
|
||||
def state_dict(self) -> OrderedDict:
|
||||
state_dict = OrderedDict()
|
||||
if self.adapter_type == 'photo_maker':
|
||||
if self.config.train_image_encoder:
|
||||
state_dict["id_encoder"] = self.vision_encoder.state_dict()
|
||||
|
||||
state_dict["fuse_module"] = self.fuse_module.state_dict()
|
||||
|
||||
# todo save LoRA
return state_dict
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def condition_prompt(
|
||||
self,
|
||||
prompt: Union[List[str], str],
|
||||
is_unconditional: bool = False,
|
||||
):
|
||||
if self.adapter_type == 'photo_maker':
|
||||
if is_unconditional:
|
||||
return prompt
|
||||
else:
|
||||
|
||||
with torch.no_grad():
|
||||
was_list = isinstance(prompt, list)
|
||||
if not was_list:
|
||||
prompt_list = [prompt]
|
||||
else:
|
||||
prompt_list = prompt
|
||||
|
||||
new_prompt_list = []
|
||||
token_mask_list = []
|
||||
|
||||
for prompt in prompt_list:
|
||||
|
||||
our_class = None
|
||||
# find a class in the prompt
|
||||
prompt_parts = prompt.split(' ')
|
||||
prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0]
|
||||
|
||||
new_prompt_parts = []
|
||||
tokened_prompt_parts = []
|
||||
for idx, prompt_part in enumerate(prompt_parts):
|
||||
new_prompt_parts.append(prompt_part)
|
||||
tokened_prompt_parts.append(prompt_part)
|
||||
if prompt_part in self.config.class_names:
|
||||
our_class = prompt_part
|
||||
# add the flag word
|
||||
tokened_prompt_parts.append(self.flag_word)
|
||||
|
||||
if self.num_control_images > 1:
|
||||
# add the rest
|
||||
for _ in range(self.num_control_images - 1):
|
||||
new_prompt_parts.extend(prompt_parts[idx + 1:])
|
||||
|
||||
# add the rest
|
||||
tokened_prompt_parts.extend(prompt_parts[idx + 1:])
|
||||
new_prompt_parts.extend(prompt_parts[idx + 1:])
|
||||
|
||||
break
|
||||
|
||||
prompt = " ".join(new_prompt_parts)
|
||||
tokened_prompt = " ".join(tokened_prompt_parts)
|
||||
|
||||
if our_class is None:
|
||||
# add the first one to the front of the prompt
|
||||
tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
|
||||
our_class = self.config.class_names[0]
|
||||
prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
|
||||
|
||||
# add the prompt to the list
|
||||
new_prompt_list.append(prompt)
|
||||
|
||||
# tokenize them with just the first tokenizer
|
||||
tokenizer = self.sd_ref().tokenizer
|
||||
if isinstance(tokenizer, list):
|
||||
tokenizer = tokenizer[0]
|
||||
|
||||
flag_token = tokenizer.convert_tokens_to_ids(self.flag_word)
|
||||
|
||||
tokenized_prompt = tokenizer.encode(prompt)
|
||||
tokenized_tokened_prompt = tokenizer.encode(tokened_prompt)
|
||||
|
||||
flag_idx = tokenized_tokened_prompt.index(flag_token)
|
||||
|
||||
class_token = tokenized_prompt[flag_idx - 1]
|
||||
|
||||
|
||||
boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
|
||||
boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
|
||||
boolean_mask = boolean_mask.to(self.device)
|
||||
# zero pad it to 77
|
||||
boolean_mask = F.pad(boolean_mask, (0, 77 - boolean_mask.shape[0]), value=False)
|
||||
|
||||
token_mask_list.append(boolean_mask)
|
||||
|
||||
self.token_mask = torch.cat(token_mask_list, dim=0).to(self.device)
|
||||
|
||||
prompt_list = new_prompt_list
|
||||
|
||||
if not was_list:
|
||||
prompt = prompt_list[0]
|
||||
else:
|
||||
prompt = prompt_list
|
||||
|
||||
return prompt
|
||||
|
||||
def condition_encoded_embeds(
|
||||
self,
|
||||
tensors_0_1: torch.Tensor,
|
||||
prompt_embeds: PromptEmbeds,
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
is_unconditional=False
|
||||
) -> PromptEmbeds:
|
||||
if self.adapter_type == 'photo_maker':
|
||||
if is_unconditional:
|
||||
# we don't condition the negative embeds for photo maker
|
||||
return prompt_embeds
|
||||
with torch.no_grad():
|
||||
# on training the clip image is created in the dataloader
|
||||
if not has_been_preprocessed:
|
||||
# tensors should be 0-1
|
||||
if tensors_0_1.ndim == 3:
|
||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||
# training tensors are 0 - 1
|
||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||
# if images are out of this range throw error
|
||||
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||
tensors_0_1.min(), tensors_0_1.max()
|
||||
))
|
||||
clip_image = self.image_processor(
|
||||
images=tensors_0_1,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
).pixel_values
|
||||
else:
|
||||
clip_image = tensors_0_1
|
||||
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
|
||||
|
||||
|
||||
# Embeddings need to be (b, num_inputs, c, h, w); for now, just put 1 input image
|
||||
clip_image = clip_image.unsqueeze(1)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training and self.config.train_image_encoder:
|
||||
self.vision_encoder.train()
|
||||
clip_image = clip_image.requires_grad_(True)
|
||||
id_embeds = self.vision_encoder(
|
||||
clip_image,
|
||||
do_projection2=isinstance(self.sd_ref().text_encoder, list),
|
||||
)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
self.vision_encoder.eval()
|
||||
id_embeds = self.vision_encoder(
|
||||
clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
|
||||
).detach()
|
||||
|
||||
prompt_embeds.text_embeds = self.fuse_module(
|
||||
prompt_embeds.text_embeds,
|
||||
id_embeds,
|
||||
self.token_mask
|
||||
)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
if self.config.type == 'photo_maker':
|
||||
yield from self.fuse_module.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
toolkit/photomaker.py (new file, 141 lines)
@@ -0,0 +1,141 @@
|
||||
# Merge image encoder and fuse module to create an ID Encoder
|
||||
# By sending multiple ID images, we can directly obtain an updated text embedding containing a stacked ID embedding
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
|
||||
from transformers.models.clip.configuration_clip import CLIPVisionConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
VISION_CONFIG_DICT = {
|
||||
"hidden_size": 1024,
|
||||
"intermediate_size": 4096,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768
|
||||
}
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
|
||||
super().__init__()
|
||||
if use_residual:
|
||||
assert in_dim == out_dim
|
||||
self.layernorm = nn.LayerNorm(in_dim)
|
||||
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.use_residual = use_residual
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
if self.use_residual:
|
||||
x = x + residual
|
||||
return x
|
||||
|
||||
|
||||
class FuseModule(nn.Module):
|
||||
def __init__(self, embed_dim):
|
||||
super().__init__()
|
||||
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
|
||||
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
|
||||
self.layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def fuse_fn(self, prompt_embeds, id_embeds):
|
||||
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
||||
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
||||
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
||||
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
||||
return stacked_id_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompt_embeds,
|
||||
id_embeds,
|
||||
class_tokens_mask,
|
||||
) -> torch.Tensor:
|
||||
# id_embeds shape: [b, max_num_inputs, 1, 2048]
|
||||
id_embeds = id_embeds.to(prompt_embeds.dtype)
|
||||
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
|
||||
batch_size, max_num_inputs = id_embeds.shape[:2]
|
||||
# seq_length: 77
|
||||
seq_length = prompt_embeds.shape[1]
|
||||
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
|
||||
flat_id_embeds = id_embeds.view(
|
||||
-1, id_embeds.shape[-2], id_embeds.shape[-1]
|
||||
)
|
||||
# valid_id_mask [b*max_num_inputs]
|
||||
valid_id_mask = (
|
||||
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
|
||||
< num_inputs[:, None]
|
||||
)
|
||||
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
|
||||
|
||||
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
|
||||
class_tokens_mask = class_tokens_mask.view(-1)
|
||||
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
|
||||
# slice out the image token embeddings
|
||||
image_token_embeds = prompt_embeds[class_tokens_mask]
|
||||
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
|
||||
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
||||
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
||||
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
||||
return updated_prompt_embeds
|
||||
|
||||
class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
|
||||
def __init__(self, config=None, *model_args, **model_kwargs):
|
||||
if config is None:
|
||||
config = CLIPVisionConfig(**VISION_CONFIG_DICT)
|
||||
super().__init__(config, *model_args, **model_kwargs)
|
||||
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
|
||||
self.fuse_module = FuseModule(2048)
|
||||
|
||||
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
|
||||
b, num_inputs, c, h, w = id_pixel_values.shape
|
||||
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
|
||||
|
||||
shared_id_embeds = self.vision_model(id_pixel_values)[1]
|
||||
id_embeds = self.visual_projection(shared_id_embeds)
|
||||
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
|
||||
|
||||
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
|
||||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
||||
|
||||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
||||
updated_prompt_embeds = self.fuse_module(
|
||||
prompt_embeds, id_embeds, class_tokens_mask)
|
||||
|
||||
return updated_prompt_embeds
|
||||
|
||||
|
||||
class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
|
||||
def __init__(self, config=None, *model_args, **model_kwargs):
|
||||
if config is None:
|
||||
config = CLIPVisionConfig(**VISION_CONFIG_DICT)
|
||||
super().__init__(config, *model_args, **model_kwargs)
|
||||
self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
|
||||
|
||||
def forward(self, id_pixel_values, do_projection2=True):
|
||||
b, num_inputs, c, h, w = id_pixel_values.shape
|
||||
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
|
||||
|
||||
shared_id_embeds = self.vision_model(id_pixel_values)[1]
|
||||
id_embeds = self.visual_projection(shared_id_embeds)
|
||||
|
||||
id_embeds = id_embeds.view(b, num_inputs, 1, -1)
|
||||
|
||||
if do_projection2:
|
||||
id_embeds_2 = self.visual_projection_2(shared_id_embeds)
|
||||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
||||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
||||
|
||||
return id_embeds
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PhotoMakerIDEncoder()
|
||||
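To make the tensor contract of FuseModule above concrete, here is a standalone sketch with random tensors; the 2048 embed dim and 77-token sequence length match the comments in forward(), everything else is illustrative.

# Standalone sketch: exercising FuseModule with the shapes its forward() expects.
import torch
from toolkit.photomaker import FuseModule

fuse = FuseModule(embed_dim=2048)          # 2048 = concatenated SDXL text-encoder dims
prompt_embeds = torch.randn(1, 77, 2048)   # [b, seq_len, embed_dim]
id_embeds = torch.randn(1, 1, 1, 2048)     # [b, max_num_inputs, 1, embed_dim]

# Boolean mask marking which prompt tokens are the (repeated) class-word tokens.
class_tokens_mask = torch.zeros(1, 77, dtype=torch.bool)
class_tokens_mask[0, 5] = True             # one ID image -> one masked token

updated = fuse(prompt_embeds, id_embeds, class_tokens_mask)
print(updated.shape)                       # torch.Size([1, 77, 2048])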
toolkit/photomaker_pipeline.py (new file, 491 lines)
@@ -0,0 +1,491 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torchvision import transforms as T
|
||||
|
||||
from safetensors import safe_open
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import (
|
||||
_get_model_file,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
from .photomaker import PhotoMakerIDEncoder
|
||||
|
||||
PipelineImageInput = Union[
|
||||
PIL.Image.Image,
|
||||
torch.FloatTensor,
|
||||
List[PIL.Image.Image],
|
||||
List[torch.FloatTensor],
|
||||
]
|
||||
|
||||
|
||||
class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
@validate_hf_hub_args
|
||||
def load_photomaker_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str,
|
||||
subfolder: str = '',
|
||||
trigger_word: str = 'img',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
weight_name (`str`):
|
||||
The weight name NOT the path to the weight.
|
||||
|
||||
subfolder (`str`, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
trigger_word (`str`, *optional*, defaults to `"img"`):
|
||||
The trigger word is used to identify the position of class word in the text prompt,
|
||||
and it is recommended not to set it as a common word.
|
||||
This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation.
|
||||
"""
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"id_encoder": {}, "lora_weights": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("id_encoder."):
|
||||
state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("lora_weights."):
|
||||
state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["id_encoder", "lora_weights"]:
|
||||
raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.")
|
||||
|
||||
self.trigger_word = trigger_word
|
||||
# load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
|
||||
print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...")
|
||||
id_encoder = PhotoMakerIDEncoder()
|
||||
id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
|
||||
id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
|
||||
self.id_encoder = id_encoder
|
||||
self.id_image_processor = CLIPImageProcessor()
|
||||
|
||||
# load lora into models
|
||||
print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]")
|
||||
self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
|
||||
|
||||
# Add trigger word token
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)
|
||||
|
||||
self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
|
||||
|
||||
def encode_prompt_with_trigger_word(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: Optional[str] = None,
|
||||
num_id_images: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Find the token id of the trigger word
|
||||
image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt_2]
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
||||
input_ids = tokenizer.encode(prompt) # TODO: batch encode
|
||||
clean_index = 0
|
||||
clean_input_ids = []
|
||||
class_token_index = []
|
||||
# Find the corresponding class word token based on the newly added trigger word token
|
||||
for i, token_id in enumerate(input_ids):
|
||||
if token_id == image_token_id:
|
||||
class_token_index.append(clean_index - 1)
|
||||
else:
|
||||
clean_input_ids.append(token_id)
|
||||
clean_index += 1
|
||||
|
||||
if len(class_token_index) != 1:
|
||||
raise ValueError(
|
||||
f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
|
||||
Trigger word: {self.trigger_word}, Prompt: {prompt}."
|
||||
)
|
||||
class_token_index = class_token_index[0]
|
||||
|
||||
# Expand the class word token and corresponding mask
|
||||
class_token = clean_input_ids[class_token_index]
|
||||
clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \
|
||||
clean_input_ids[class_token_index + 1:]
|
||||
|
||||
# Truncation or padding
|
||||
max_len = tokenizer.model_max_length
|
||||
if len(clean_input_ids) > max_len:
|
||||
clean_input_ids = clean_input_ids[:max_len]
|
||||
else:
|
||||
clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
|
||||
max_len - len(clean_input_ids)
|
||||
)
|
||||
|
||||
class_tokens_mask = [True if class_token_index <= i < class_token_index + num_id_images else False \
|
||||
for i in range(len(clean_input_ids))]
|
||||
|
||||
clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
|
||||
class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
clean_input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
# Added parameters (for PhotoMaker)
|
||||
input_id_images: PipelineImageInput = None,
|
||||
start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future
|
||||
class_tokens_mask: Optional[torch.LongTensor] = None,
|
||||
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Only the parameters introduced by PhotoMaker are discussed here.
|
||||
For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
|
||||
|
||||
Args:
|
||||
input_id_images (`PipelineImageInput`, *optional*):
|
||||
Input ID Image to work with PhotoMaker.
|
||||
class_tokens_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
|
||||
prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
#
|
||||
if prompt_embeds is not None and class_tokens_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
|
||||
)
|
||||
# check the input id images
|
||||
if input_id_images is None:
|
||||
raise ValueError(
|
||||
"Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
|
||||
)
|
||||
if not isinstance(input_id_images, list):
|
||||
input_id_images = [input_id_images]
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
assert do_classifier_free_guidance
|
||||
|
||||
# 3. Encode input prompt
|
||||
num_id_images = len(input_id_images)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
class_tokens_mask,
|
||||
) = self.encode_prompt_with_trigger_word(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_id_images=num_id_images,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
class_tokens_mask=class_tokens_mask,
|
||||
)
|
||||
|
||||
# 4. Encode input prompt without the trigger word for delayed conditioning
|
||||
prompt_text_only = prompt.replace(" " + self.trigger_word, "") # sensitive to white space
|
||||
(
|
||||
prompt_embeds_text_only,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt_text_only,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds_text_only,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds_text_only,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 5. Prepare the input ID images
|
||||
dtype = next(self.id_encoder.parameters()).dtype
|
||||
if not isinstance(input_id_images[0], torch.Tensor):
|
||||
id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values
|
||||
|
||||
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts
|
||||
|
||||
# 6. Get the updated text embedding with the stacked ID embedding
|
||||
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
# 7. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 8. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if i <= start_merge_step:
|
||||
current_prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
|
||||
)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
|
||||
else:
|
||||
current_prompt_embeds = torch.cat(
|
||||
[negative_prompt_embeds, prompt_embeds], dim=0
|
||||
)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=current_prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# apply watermark if available
|
||||
# if self.watermark is not None:
|
||||
# image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
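As a reading aid for the pipeline above, here is a hedged sketch of how PhotoMakerStableDiffusionXLPipeline is meant to be driven, based only on the signatures and docstrings in this file; the base model id, weight file name, and image path are placeholders, not artifacts of this commit.

# Hypothetical usage sketch -- model ids and file names are placeholders.
import torch
from PIL import Image
from toolkit.photomaker_pipeline import PhotoMakerStableDiffusionXLPipeline

pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Loads the id_encoder + lora_weights bundle and registers the trigger token.
pipe.load_photomaker_adapter(
    "path/to/photomaker_weights_dir",      # placeholder repo or local directory
    weight_name="photomaker.safetensors",  # placeholder weight file
    trigger_word="img",
)

# The trigger word must come right after the class word ("man img"), per the docstring.
result = pipe(
    prompt="a photo of a man img wearing a suit",
    input_id_images=[Image.open("face.jpg")],  # one or more ID images of the same person
    num_inference_steps=30,
    start_merge_step=10,                       # earlier steps use the text-only embeds
)
result.images[0].save("photomaker_result.png")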
@@ -246,6 +246,25 @@ def load_ip_adapter_model(
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
def load_custom_adapter_model(
|
||||
path_to_file,
|
||||
device: Union[str] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
# check if it is safetensors or checkpoint
|
||||
if path_to_file.endswith('.safetensors'):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
combined_state_dict = OrderedDict()
|
||||
for combo_key, value in raw_state_dict.items():
|
||||
key_split = combo_key.split('.')
|
||||
module_name = key_split.pop(0)
|
||||
if module_name not in combined_state_dict:
|
||||
combined_state_dict[module_name] = OrderedDict()
|
||||
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
|
||||
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
|
||||
lora_keymap = OrderedDict()
|
||||
|
||||
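The load_custom_adapter_model helper added above groups flat safetensors keys by their first dot-separated segment, which is the layout CustomAdapter.load_state_dict expects; the sketch below illustrates the transformation with made-up keys and an example path.

# Illustrative only: key grouping performed by load_custom_adapter_model.
# Flat keys in the .safetensors file, e.g.:
#   "fuse_module.mlp1.fc1.weight", "id_encoder.visual_projection_2.weight"
# become a nested OrderedDict keyed by the first segment:
#   {"fuse_module": {"mlp1.fc1.weight": ...}, "id_encoder": {"visual_projection_2.weight": ...}}
from toolkit.saving import load_custom_adapter_model

state_dict = load_custom_adapter_model(
    "output/my_adapter_000001000.safetensors",  # example path
    device='cpu',
)
print(list(state_dict.keys()))  # e.g. ['fuse_module', 'id_encoder']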
@@ -19,6 +19,7 @@ from tqdm import tqdm
|
||||
from torchvision.transforms import Resize, transforms
|
||||
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
@@ -483,6 +484,13 @@ class StableDiffusion:
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
validation_image = transform(validation_image)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
# todo allow loading multiple
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
validation_image = transform(validation_image)
|
||||
self.adapter.num_images = 1
|
||||
if isinstance(self.adapter, ReferenceAdapter):
|
||||
# need -1 to 1
|
||||
validation_image = transforms.ToTensor()(validation_image)
|
||||
@@ -501,6 +509,19 @@ class StableDiffusion:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||
# handle condition the prompts
|
||||
gen_config.prompt = self.adapter.condition_prompt(
|
||||
gen_config.prompt,
|
||||
is_unconditional=False,
|
||||
)
|
||||
gen_config.prompt_2 = gen_config.prompt
|
||||
gen_config.negative_prompt = self.adapter.condition_prompt(
|
||||
gen_config.negative_prompt,
|
||||
is_unconditional=True,
|
||||
)
|
||||
gen_config.negative_prompt_2 = gen_config.negative_prompt
|
||||
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
|
||||
@@ -524,6 +545,21 @@ class StableDiffusion:
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
|
||||
conditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=validation_image,
|
||||
prompt_embeds=conditional_embeds,
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
)
|
||||
unconditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=validation_image,
|
||||
prompt_embeds=unconditional_embeds,
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
is_unconditional=True,
|
||||
)
|
||||
|
||||
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||
# if we have a refiner loaded, set the denoising end at the refiner start
|
||||
extra['denoising_end'] = gen_config.refiner_start_at
|
||||
@@ -1468,6 +1504,9 @@ class StableDiffusion:
|
||||
elif isinstance(self.adapter, ClipVisionAdapter):
|
||||
requires_grad = self.adapter.embedder.training
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, CustomAdapter):
|
||||
requires_grad = self.adapter.training
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, ReferenceAdapter):
|
||||
# todo update this!!
|
||||
requires_grad = True
|
||||
|
||||