diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 56ec7134..4dc8669b 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1085,7 +1085,7 @@ class SDTrainer(BaseSDTrainProcess):
                 noise=noise,
                 batch=batch,
                 unconditional_embeds=unconditional_embeds
-            )
+            ).detach()
 
             # do the custom adapter after the prior prediction
             if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 36b0cb1a..ab27bd2e 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -186,6 +186,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sample_config = self.first_sample_config if is_first else self.sample_config
         start_seed = sample_config.seed
         current_seed = start_seed
+
+        test_image_paths = []
+        if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
+            test_image_path_list = self.adapter_config.test_img_path.split(',')
+            test_image_path_list = [p.strip() for p in test_image_path_list]
+            test_image_path_list = [p for p in test_image_path_list if p != '']
+            # divide up images so they are evenly distributed across prompts
+            for i in range(len(sample_config.prompts)):
+                test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
+
         for i in range(len(sample_config.prompts)):
             if sample_config.walk_seed:
                 current_seed = start_seed + i
@@ -219,7 +229,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             extra_args = {}
             if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
-                extra_args['adapter_image_path'] = self.adapter_config.test_img_path
+                extra_args['adapter_image_path'] = test_image_paths[i]
 
             gen_img_config_list.append(GenerateImageConfig(
                 prompt=prompt,  # it will autoparse the prompt
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2803a8e6..284a3227 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -403,6 +403,7 @@ class DatasetConfig:
         # remove empty lines
         random_triggers = [line for line in random_triggers if line.strip() != '']
         self.random_triggers: List[str] = random_triggers
+        self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
         self.caption_ext: str = kwargs.get('caption_ext', None)
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index ebc1b1fe..0aef07b4 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -277,7 +277,7 @@ class CustomAdapter(torch.nn.Module):
                    raise ValueError(f"unknown shape: {v.shape}")
            self.fuse_module.load_state_dict(current_state_dict, strict=strict)
 
-        if 'vision_encoder' in state_dict:
+        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
            self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
 
        if 'fuse_module' in state_dict:
@@ -411,7 +411,7 @@ class CustomAdapter(torch.nn.Module):
        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
            if is_unconditional:
                # we dont condition the negative embeds for photo maker
-                return prompt_embeds
+                return prompt_embeds.clone()
            with torch.no_grad():
                # on training the clip image is created in the dataloader
                if not has_been_preprocessed:
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 7f2c52c9..3a82511e 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -348,8 +348,14 @@ class CaptionProcessingDTOMixin:
            caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
 
        if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
-            # add random triggers
-            caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
+            num_triggers = self.dataset_config.random_triggers_max
+            if num_triggers > 1:
+                num_triggers = random.randint(0, num_triggers)
+
+            if num_triggers > 0:
+                # add random triggers
+                for i in range(num_triggers):
+                    caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
 
        if self.dataset_config.shuffle_tokens:
            # shuffle again
diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py
index 3a9448b0..73fb3ffb 100644
--- a/toolkit/models/clip_fusion.py
+++ b/toolkit/models/clip_fusion.py
@@ -86,6 +86,49 @@ class ZipperBlock(nn.Module):
        return x
 
 
+class ContextualAlphaMask(nn.Module):
+    def __init__(
+            self,
+            dim: int = 768,
+    ):
+        super(ContextualAlphaMask, self).__init__()
+        self.dim = dim
+
+        half_dim = dim // 2
+        quarter_dim = dim // 4
+
+        self.fc1 = nn.Linear(self.dim, self.dim)
+        self.fc2 = nn.Linear(self.dim, half_dim)
+        self.norm1 = nn.LayerNorm(half_dim)
+        self.fc3 = nn.Linear(half_dim, half_dim)
+        self.fc4 = nn.Linear(half_dim, quarter_dim)
+        self.norm2 = nn.LayerNorm(quarter_dim)
+        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
+        self.fc6 = nn.Linear(quarter_dim, 1)
+        # set fc6 weights to near zero
+        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        # x = (batch_size, 77, 768)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.norm1(x)
+        x = self.act_fn(x)
+        x = self.fc3(x)
+        x = self.act_fn(x)
+        x = self.fc4(x)
+        x = self.norm2(x)
+        x = self.act_fn(x)
+        x = self.fc5(x)
+        x = self.act_fn(x)
+        x = self.fc6(x)
+        x = torch.sigmoid(x)
+        return x
+
+
+
 # CLIPFusionModule
 # Fuses any size of vision and text embeddings into a single embedding.
 # remaps tokens and vectors.
@@ -96,7 +139,7 @@ class CLIPFusionModule(nn.Module):
            text_tokens: int = 77,
            vision_hidden_size: int = 1024,
            vision_tokens: int = 257,
-            num_blocks: int = 2,
+            num_blocks: int = 1,
    ):
        super(CLIPFusionModule, self).__init__()
 
@@ -125,6 +168,10 @@ class CLIPFusionModule(nn.Module):
            ) for i in range(num_blocks)
        ])
 
+        self.ctx_alpha = ContextualAlphaMask(
+            dim=self.text_hidden_size,
+        )
+
    def forward(self, text_embeds, vision_embeds):
        # text_embeds = (batch_size, 77, 768)
        # vision_embeds = (batch_size, 257, 1024)
@@ -138,6 +185,8 @@ class CLIPFusionModule(nn.Module):
            x = block(x)
            x = x + res
 
-        x = text_embeds + x
+        # alpha mask
+        alpha = self.ctx_alpha(text_embeds)
+        x = alpha * x + (1 - alpha) * text_embeds
        return x
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index d0fb6b87..c81ef442 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -804,28 +804,27 @@ class StableDiffusion:
            detach_unconditional=False,
            **kwargs,
    ):
-        with torch.no_grad():
-            # get the embeddings
-            if text_embeddings is None and conditional_embeddings is None:
-                raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
-            if text_embeddings is None and unconditional_embeddings is not None:
-                text_embeddings = concat_prompt_embeds([
-                    unconditional_embeddings,  # negative embedding
-                    conditional_embeddings,  # positive embedding
-                ])
-            elif text_embeddings is None and conditional_embeddings is not None:
-                # not doing cfg
-                text_embeddings = conditional_embeddings
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = concat_prompt_embeds([
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+            ])
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings
 
-            # CFG is comparing neg and positive, if we have concatenated embeddings
-            # then we are doing it, otherwise we are not and takes half the time.
-            do_classifier_free_guidance = True
+        # CFG is comparing neg and positive, if we have concatenated embeddings
+        # then we are doing it, otherwise we are not and takes half the time.
+        do_classifier_free_guidance = True
 
-            # check if batch size of embeddings matches batch size of latents
-            if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
-                do_classifier_free_guidance = False
-            elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
-                raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
 
        latents = latents.to(self.device_torch)
        text_embeddings = text_embeddings.to(self.device_torch)
        timestep = timestep.to(self.device_torch)
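Note on the toolkit/dataloader_mixins.py change above: with the new random_triggers_max dataset option left at its default of 1, the old behavior is preserved and exactly one random trigger is appended; for values above 1 the number of appended triggers is drawn uniformly from 0 to the maximum, so some captions receive none. A minimal standalone sketch of that logic, using a hypothetical add_random_triggers helper that simply mirrors the patched code:

import random

def add_random_triggers(caption, random_triggers, random_triggers_max=1):
    # mirrors CaptionProcessingDTOMixin: decide how many triggers to append
    num_triggers = random_triggers_max
    if num_triggers > 1:
        # randint is inclusive on both ends, so zero triggers is possible here
        num_triggers = random.randint(0, num_triggers)
    for _ in range(num_triggers):
        caption = caption + ', ' + random.choice(random_triggers)
    return caption

print(add_random_triggers('a photo of a person', ['style_a', 'style_b'], random_triggers_max=3))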
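Note on the toolkit/models/clip_fusion.py change above: the fused features are no longer added directly to the text embeddings; a per-token gate predicted from the text embeddings now blends the two. A minimal sketch of that blend, with a single near-zero-initialized Linear standing in for the ContextualAlphaMask MLP (a simplification made only to keep the example short):

import torch
import torch.nn as nn

batch, tokens, dim = 2, 77, 768
text_embeds = torch.randn(batch, tokens, dim)
x = torch.randn(batch, tokens, dim)  # fused text+vision features from the zipper blocks

gate = nn.Linear(dim, 1)                        # stand-in for ContextualAlphaMask
gate.weight.data.normal_(mean=0.0, std=0.0001)  # near-zero init, as in the patch

alpha = torch.sigmoid(gate(text_embeds))        # (batch, 77, 1), roughly 0.5 at init

out_old = text_embeds + x                        # previous behavior: plain residual add
out_new = alpha * x + (1 - alpha) * text_embeds  # new behavior: per-token convex blend
print(out_old.shape, out_new.shape)              # both torch.Size([2, 77, 768])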