From a05459afaff4ce2e41a01a676c72cf300b295729 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 15 Oct 2023 15:13:35 -0600 Subject: [PATCH] Fixed issue with adapters that only had 1 input channel. Added ability to set the percentage chance of adapter matching --- extensions_built_in/sd_trainer/SDTrainer.py | 109 +++++++++++++------- toolkit/config_modules.py | 8 +- 2 files changed, 79 insertions(+), 38 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index bc78b369..ca14ee89 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -22,6 +22,7 @@ def flush(): torch.cuda.empty_cache() gc.collect() + adapter_transforms = transforms.Compose([ # transforms.PILToTensor(), transforms.ToTensor(), @@ -51,7 +52,6 @@ class SDTrainer(BaseSDTrainProcess): self.assistant_adapter.requires_grad_(False) flush() - def hook_before_train_loop(self): # move vae to device if we did not cache latents if not self.is_latents_cached: @@ -62,7 +62,6 @@ class SDTrainer(BaseSDTrainProcess): self.sd.vae.to('cpu') flush() - def hook_train_loop(self, batch): self.timer.start('preprocess_batch') @@ -72,6 +71,17 @@ class SDTrainer(BaseSDTrainProcess): has_adapter_img = batch.control_tensor is not None + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + self.timer.stop('preprocess_batch') with torch.no_grad(): @@ -82,6 +92,12 @@ class SDTrainer(BaseSDTrainProcess): # todo move this to data loader if batch.control_tensor is not None: adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] else: raise NotImplementedError("Adapter images now must be loaded with dataloader") # not 100% sure what this does. But they do it here @@ -106,7 +122,7 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter and isinstance(self.adapter, T2IAdapter): # training a t2i adapter, not using as assistant. return 1.0 - elif self.train_config.match_adapter_assist: + elif match_adapter_assist: # training a texture. We want it high adapter_strength_min = 0.9 adapter_strength_max = 1.0 @@ -117,18 +133,18 @@ class SDTrainer(BaseSDTrainProcess): adapter_strength_min = 0.9 adapter_strength_max = 1.1 - adapter_conditioning_scale = torch.rand( - (1,), device=self.device_torch, dtype=dtype - ) + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) - adapter_conditioning_scale = value_map( - adapter_conditioning_scale, - 0.0, - 1.0, - adapter_strength_min, - adapter_strength_max - ) - return adapter_conditioning_scale + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale # flush() with self.timer('grad_setup'): @@ -154,11 +170,26 @@ class SDTrainer(BaseSDTrainProcess): # activate network if it exits with network: with self.timer('encode_prompt'): - with torch.set_grad_enabled(grad_on_text_encoder): - conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype) - if not grad_on_text_encoder: + if grad_on_text_encoder: + with torch.set_grad_enabled(True): + conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to( + self.device_torch, + dtype=dtype) + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to( + self.device_torch, + dtype=dtype) + # detach the embeddings conditional_embeds = conditional_embeds.detach() + # flush() pred_kwargs = {} if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): @@ -170,33 +201,38 @@ class SDTrainer(BaseSDTrainProcess): if self.assistant_adapter: # not training. detach down_block_additional_residuals = [ - sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals ] else: down_block_additional_residuals = [ - sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals ] pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals control_pred = None - if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist: - # do a prediction here so we can match its output with network multiplier set to 0.0 - with torch.no_grad(): - # dont use network on this - network.multiplier = 0.0 - control_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs # adapter residuals in here - ) - control_pred = control_pred.detach() - # remove the residuals as we wont use them on prediction when matching control - del pred_kwargs['down_block_additional_residuals'] - # restore network - network.multiplier = network_weight_list + if has_adapter_img and self.assistant_adapter and match_adapter_assist: + with self.timer('predict_with_adapter'): + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + # dont use network on this + network.multiplier = 0.0 + self.sd.unet.eval() + control_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs # adapter residuals in here + ) + self.sd.unet.train() + control_pred = control_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + del pred_kwargs['down_block_additional_residuals'] + # restore network + network.multiplier = network_weight_list if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter'): @@ -204,7 +240,6 @@ class SDTrainer(BaseSDTrainProcess): conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) - with self.timer('predict_unet'): noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 904084e8..e3149505 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -122,7 +122,13 @@ class TrainConfig: self.start_step = kwargs.get('start_step', None) self.free_u = kwargs.get('free_u', False) self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) - self.match_adapter_assist = kwargs.get('match_adapter_assist', False) + + match_adapter_assist = kwargs.get('match_adapter_assist', False) + self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) + + # legacy + if match_adapter_assist and self.match_adapter_chance == 0.0: + self.match_adapter_chance = 1.0