diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index c5ede3ca..a40a20e7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -319,6 +319,7 @@ class SDTrainer(BaseSDTrainProcess):
             pred_kwargs: dict,
             batch: 'DataLoaderBatchDTO',
             noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
             **kwargs
     ):
         loss = get_guidance_loss(
@@ -331,6 +332,7 @@ class SDTrainer(BaseSDTrainProcess):
             batch=batch,
             noise=noise,
             sd=self.sd,
+            unconditional_embeds=unconditional_embeds,
             **kwargs
         )
 
@@ -618,6 +620,7 @@ class SDTrainer(BaseSDTrainProcess):
             pred_kwargs: dict,
             batch: 'DataLoaderBatchDTO',
             noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
             **kwargs
     ):
         # todo for embeddings, we need to run without trigger words
@@ -655,9 +658,13 @@ class SDTrainer(BaseSDTrainProcess):
                 # self.network.multiplier = 0.0
             self.sd.unet.eval()
 
+            if unconditional_embeds is not None:
+                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
+
             prior_pred = self.sd.predict_noise(
                 latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                 conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
+                unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
                 guidance_scale=1.0,
                 **pred_kwargs  # adapter residuals in here
@@ -901,6 +908,7 @@ class SDTrainer(BaseSDTrainProcess):
                     self.adapter(conditional_clip_embeds)
 
         with self.timer('encode_prompt'):
+            unconditional_embeds = None
            if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
                     conditional_embeds = self.sd.encode_prompt(
@@ -909,6 +917,15 @@ class SDTrainer(BaseSDTrainProcess):
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+
+                    if self.train_config.do_cfg:
+                        # todo only do one and repeat it
+                        unconditional_embeds = self.sd.encode_prompt(
+                            ["" for _ in range(noisy_latents.shape[0])],
+                            dropout_prob=self.train_config.prompt_dropout_prob,
+                            long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype)
             else:
                 with torch.set_grad_enabled(False):
                     # make sure it is in eval mode
@@ -923,9 +940,19 @@ class SDTrainer(BaseSDTrainProcess):
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+                    if self.train_config.do_cfg:
+                        # todo only do one and repeat it
+                        unconditional_embeds = self.sd.encode_prompt(
+                            ["" for _ in range(noisy_latents.shape[0])],
+                            dropout_prob=self.train_config.prompt_dropout_prob,
+                            long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype)
 
                 # detach the embeddings
                 conditional_embeds = conditional_embeds.detach()
+                if self.train_config.do_cfg:
+                    unconditional_embeds = unconditional_embeds.detach()
                 # flush()
 
         pred_kwargs = {}
@@ -965,21 +992,43 @@ class SDTrainer(BaseSDTrainProcess):
                         drop=True,
                         is_training=True
                     )
+                    if self.train_config.do_cfg:
+                        unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                            torch.zeros(
+                                (noisy_latents.shape[0], 3, 512, 512),
+                                device=self.device_torch, dtype=dtype
+                            ).detach(),
+                            is_training=True,
+                            drop=True
+                        )
                 elif has_clip_image:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images.detach().to(self.device_torch, dtype=dtype),
                         is_training=True
                     )
+                    if self.train_config.do_cfg:
+                        unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                            torch.zeros(
+                                (noisy_latents.shape[0], 3, 512, 512),
+                                device=self.device_torch, dtype=dtype
+                            ).detach(),
+                            is_training=True,
+                            drop=True
+                        )
                 else:
                     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
 
                 if not self.adapter_config.train_image_encoder:
                     # we are not training the image encoder, so we need to detach the embeds
                     conditional_clip_embeds = conditional_clip_embeds.detach()
+                    if self.train_config.do_cfg:
+                        unconditional_clip_embeds = unconditional_clip_embeds.detach()
 
             with self.timer('encode_adapter'):
                 conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
+                if self.train_config.do_cfg:
+                    unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
 
         if self.adapter and isinstance(self.adapter, ReferenceAdapter):
             # pass in our scheduler
@@ -1017,6 +1066,7 @@ class SDTrainer(BaseSDTrainProcess):
                 pred_kwargs=pred_kwargs,
                 noise=noise,
                 batch=batch,
+                unconditional_embeds=unconditional_embeds
             )
 
         self.before_unet_predict()
@@ -1032,13 +1082,17 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs=pred_kwargs,
                     batch=batch,
                     noise=noise,
+                    unconditional_embeds=unconditional_embeds
                 )
             else:
                 with self.timer('predict_unet'):
+                    if unconditional_embeds is not None:
+                        unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
                     noise_pred = self.sd.predict_noise(
                         latents=noisy_latents.to(self.device_torch, dtype=dtype),
                         conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        unconditional_embeddings=unconditional_embeds,
                         timestep=timesteps,
                         guidance_scale=1.0,
                         **pred_kwargs
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 387d2a9d..1cd2c8f2 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -156,6 +156,7 @@ class AdapterConfig:
         self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
         self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
         self.safe_channels: int = kwargs.get('safe_channels', 2048)
+        self.safe_tokens: int = kwargs.get('safe_tokens', 8)
 
         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')
@@ -270,6 +271,7 @@ class TrainConfig:
             raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
 
         self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
+        self.do_cfg = kwargs.get('do_cfg', False)
 
 
 class ModelConfig:
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index 96d89c5e..c83a77f7 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Literal
+from typing import Literal, Optional
 
 from toolkit.basic import value_map
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
@@ -193,6 +193,7 @@ def get_direct_guidance_loss(
         batch: 'DataLoaderBatchDTO',
         noise: torch.Tensor,
         sd: 'StableDiffusion',
+        unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
 ):
     with torch.no_grad():
@@ -222,9 +223,14 @@ def get_direct_guidance_loss(
         # sd.network.multiplier = network_weight_list
 
     # do our prediction with LoRA active on the scaled guidance latents
+    if unconditional_embeds is not None:
+        unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
+        unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
+
     prediction = sd.predict_noise(
         latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
         conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
+        unconditional_embeddings=unconditional_embeds,
         timestep=torch.cat([timesteps, timesteps]),
         guidance_scale=1.0,
         **pred_kwargs  # adapter residuals in here
@@ -482,12 +488,14 @@ def get_guidance_loss(
         batch: 'DataLoaderBatchDTO',
         noise: torch.Tensor,
         sd: 'StableDiffusion',
+        unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
 ):
     # TODO add others and process individual batch items separately
     guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
 
     if guidance_type == "targeted":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
         return get_targeted_guidance_loss(
             noisy_latents,
             conditional_embeds,
@@ -501,6 +509,7 @@ def get_guidance_loss(
             **kwargs
         )
     elif guidance_type == "polarity":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
         return get_guided_loss_polarity(
             noisy_latents,
             conditional_embeds,
@@ -515,6 +524,7 @@ def get_guidance_loss(
         )
 
     elif guidance_type == "targeted_polarity":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
         return get_targeted_polarity_loss(
             noisy_latents,
             conditional_embeds,
@@ -538,6 +548,7 @@ def get_guidance_loss(
             batch,
             noise,
             sd,
+            unconditional_embeds=unconditional_embeds,
             **kwargs
         )
     else:
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index f01977c3..49cbac2f 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -184,7 +184,7 @@ class IPAdapter(torch.nn.Module):
             self.clip_image_processor = SAFEImageProcessor()
             self.image_encoder = SAFEVisionModel(
                 in_channels=3,
-                num_tokens=8,
+                num_tokens=self.config.safe_tokens,
                 num_vectors=sd.unet.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,
@@ -234,8 +234,8 @@ class IPAdapter(torch.nn.Module):
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
 
-            if self.config.image_encoder_arch == 'safe':
-                embedding_dim = self.config.safe_channels
+            # if self.config.image_encoder_arch == 'safe':
+            #     embedding_dim = self.config.safe_tokens
 
             # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
             # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
             # ip-adapter-plus
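Illustrative sketch (not part of the patch): how the new `do_cfg` flag is intended to be driven, using only calls that appear in the diff above (`encode_prompt`, `predict_noise`, and `PromptEmbeds.to`/`.detach`). Here `sd` is assumed to be an already constructed toolkit StableDiffusion wrapper, and `train_config`, `prompts`, `noisy_latents`, `timesteps`, `device`, and `dtype` are placeholders supplied by the caller; `encode_prompt` default arguments are assumed.

    unconditional_embeds = None
    if train_config.do_cfg:
        # encode one empty prompt per batch item, mirroring the SDTrainer change
        unconditional_embeds = sd.encode_prompt(
            ["" for _ in range(noisy_latents.shape[0])]
        ).to(device, dtype=dtype).detach()

    conditional_embeds = sd.encode_prompt(prompts).to(device, dtype=dtype)

    # predict_noise now also receives the unconditional embeddings;
    # passing None keeps the previous conditional-only behaviour
    noise_pred = sd.predict_noise(
        latents=noisy_latents,
        conditional_embeddings=conditional_embeds,
        unconditional_embeddings=unconditional_embeds,
        timestep=timesteps,
        guidance_scale=1.0,
    )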