diff --git a/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py index a6bb2f46..76ff7994 100644 --- a/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py +++ b/extensions_built_in/ultimate_slider_trainer/UltimateSliderTrainerProcess.py @@ -8,7 +8,7 @@ from torch.utils.data import ConcatDataset, DataLoader from toolkit.config_modules import ReferenceDatasetConfig from toolkit.data_loader import PairedImageDataset -from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds, build_latent_image_batch_for_prompt_pair from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds from toolkit.train_tools import get_torch_dtype, apply_snr_weight import gc @@ -44,6 +44,8 @@ class UltimateSliderConfig(SliderConfig): super().__init__(**kwargs) self.additional_losses: List[str] = kwargs.get('additional_losses', []) self.weight_jitter: float = kwargs.get('weight_jitter', 0.0) + self.img_loss_weight: float = kwargs.get('img_loss_weight', 1.0) + self.cfg_loss_weight: float = kwargs.get('cfg_loss_weight', 1.0) self.datasets: List[ReferenceDatasetConfig] = [ReferenceDatasetConfig(**d) for d in kwargs.get('datasets', [])] @@ -189,7 +191,6 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): # do them one at a time (probably not necessary after new optimizations) prompt_pairs += [x.to('cpu') for x in prompt_pair_batch] - # move to cpu to save vram # We don't need text encoder anymore, but keep it on cpu for sampling # if text encoder is list @@ -216,12 +217,22 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): # end hook_before_train_loop def hook_train_loop(self, batch): + dtype = get_torch_dtype(self.train_config.dtype) + with torch.no_grad(): ### LOOP SETUP ### noise_scheduler = self.sd.noise_scheduler optimizer = self.optimizer lr_scheduler = self.lr_scheduler + ### TARGET_PROMPTS ### + # get a random pair + prompt_pair: EncodedPromptPair = self.prompt_pairs[ + torch.randint(0, len(self.prompt_pairs), (1,)).item() + ] + # move to device and dtype + prompt_pair.to(self.device_torch, dtype=dtype) + ### PREP REFERENCE IMAGES ### imgs, prompts, network_weights = batch @@ -240,8 +251,6 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): network_neg_weight += jitter_list # if items in network_weight list are tensors, convert them to floats - - dtype = get_torch_dtype(self.train_config.dtype) imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype) # split batched images in half so left is negative and right is positive negative_images, positive_images = torch.chunk(imgs, 2, dim=3) @@ -258,6 +267,8 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): ) timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch) + current_timestep_index = timesteps.item() + current_timestep = noise_scheduler.timesteps[current_timestep_index] timesteps = timesteps.long() # get noise @@ -275,6 +286,63 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): noisy_positive_latents = noise_scheduler.add_noise(positive_latents, noise_positive, timesteps) noisy_negative_latents = noise_scheduler.add_noise(negative_latents, noise_negative, timesteps) + ### CFG SLIDER TRAINING PREP ### + + # get CFG txt latents + noisy_cfg_latents = build_latent_image_batch_for_prompt_pair( + pos_latent=noisy_positive_latents, + neg_latent=noisy_negative_latents, + prompt_pair=prompt_pair, + prompt_chunk_size=self.prompt_chunk_size, + ) + noisy_cfg_latents.requires_grad = False + + assert not self.network.is_active + + # 4.20 GB RAM for 512x512 + positive_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.negative_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + positive_latents.requires_grad = False + + neutral_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.empty_prompt, # positive prompt (normally neutral + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + neutral_latents.requires_grad = False + + unconditional_latents = self.sd.predict_noise( + latents=noisy_cfg_latents, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # negative prompt + prompt_pair.positive_target, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + unconditional_latents.requires_grad = False + + positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0) + neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0) + unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0) + prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size) + noisy_cfg_latents_chunks = torch.chunk(noisy_cfg_latents, self.prompt_chunk_size, dim=0) + assert len(prompt_pair_chunks) == len(noisy_cfg_latents_chunks) + noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0) noise = torch.cat([noise_positive, noise_negative], dim=0) timesteps = torch.cat([timesteps, timesteps], dim=0) @@ -329,7 +397,9 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): timesteps_list = [timesteps] conditional_embeds_list = [conditional_embeds] - losses = [] + ## DO REFERENCE IMAGE TRAINING ## + + reference_image_losses = [] # allow to chunk it out to save vram for network_multiplier, noisy_latents, noise, timesteps, conditional_embeds in zip( network_multiplier_list, noisy_latent_list, noise_list, timesteps_list, conditional_embeds_list @@ -361,15 +431,88 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() + loss = loss * self.slider_config.img_loss_weight loss_slide_float = loss.item() loss_float = loss.item() - losses.append(loss_float) + reference_image_losses.append(loss_float) # back propagate loss to free ram loss.backward() flush() + ## DO CFG SLIDER TRAINING ## + + cfg_loss_list = [] + + with self.network: + assert self.network.is_active + for prompt_pair_chunk, \ + noisy_cfg_latent_chunk, \ + positive_latents_chunk, \ + neutral_latents_chunk, \ + unconditional_latents_chunk \ + in zip( + prompt_pair_chunks, + noisy_cfg_latents_chunks, + positive_latents_chunks, + neutral_latents_chunks, + unconditional_latents_chunks, + ): + self.network.multiplier = prompt_pair_chunk.multiplier_list + + target_latents = self.sd.predict_noise( + latents=noisy_cfg_latent_chunk, + text_embeddings=train_tools.concat_prompt_embeddings( + prompt_pair_chunk.positive_target, # negative prompt + prompt_pair_chunk.target_class, # positive prompt + self.train_config.batch_size, + ), + timestep=current_timestep, + guidance_scale=1.0 + ) + + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk) + + # make offset multiplier based on actions + offset_multiplier_list = [] + for action in prompt_pair_chunk.action_list: + if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE: + offset_multiplier_list += [-1.0] + elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE: + offset_multiplier_list += [1.0] + + offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype) + # make offset multiplier match rank of offset + offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1) + offset *= offset_multiplier + + offset_neutral = neutral_latents_chunk + # offsets are already adjusted on a per-batch basis + offset_neutral += offset + + # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing + loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, + self.train_config.min_snr_gamma) + + loss = loss.mean() * prompt_pair_chunk.weight * self.slider_config.cfg_loss_weight + + loss.backward() + cfg_loss_list.append(loss.item()) + del target_latents + del offset_neutral + del loss + flush() + # apply gradients optimizer.step() lr_scheduler.step() @@ -377,9 +520,14 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess): # reset network self.network.multiplier = 1.0 - loss_dict = OrderedDict( - {'loss': sum(losses) / len(losses) if len(losses) > 0 else 0.0} - ) + reference_image_loss = sum(reference_image_losses) / len(reference_image_losses) if len( + reference_image_losses) > 0 else 0.0 + cfg_loss = sum(cfg_loss_list) / len(cfg_loss_list) if len(cfg_loss_list) > 0 else 0.0 + + loss_dict = OrderedDict({ + 'loss/img': reference_image_loss, + 'loss/cfg': cfg_loss, + }) return loss_dict # end hook_train_loop diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index aa5f42e5..d7e9da1f 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -34,7 +34,8 @@ class EncodedPromptPair: action_list=None, multiplier=1.0, multiplier_list=None, - weight=1.0 + weight=1.0, + target: 'SliderTargetConfig' = None, ): self.target_class: PromptEmbeds = target_class self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral @@ -46,6 +47,7 @@ class EncodedPromptPair: self.empty_prompt: PromptEmbeds = empty_prompt self.both_targets: PromptEmbeds = both_targets self.multiplier: float = multiplier + self.target: 'SliderTargetConfig' = target if multiplier_list is not None: self.multiplier_list: list[float] = multiplier_list else: @@ -109,7 +111,8 @@ def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]): both_targets=both_targets, action_list=action_list, multiplier_list=multiplier_list, - weight=weight + weight=weight, + target=prompt_pairs[0].target ) @@ -160,7 +163,8 @@ def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List both_targets=both_targets_splits[i], action_list=action_list_split, multiplier_list=multiplier_list_split, - weight=concatenated.weight + weight=concatenated.weight, + target=concatenated.target ) prompt_pairs.append(prompt_pair) @@ -358,7 +362,8 @@ def build_prompt_pair_batch_from_cache( multiplier=target.multiplier, both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], - weight=target.weight + weight=target.weight, + target=target ), ] if both or enhance_positive: @@ -377,7 +382,8 @@ def build_prompt_pair_batch_from_cache( multiplier=target.multiplier, both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], - weight=target.weight + weight=target.weight, + target=target ), ] if both or enhance_positive: @@ -396,7 +402,8 @@ def build_prompt_pair_batch_from_cache( both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], multiplier=target.multiplier * -1.0, - weight=target.weight + weight=target.weight, + target=target ), ] if both or erase_negative: @@ -415,8 +422,39 @@ def build_prompt_pair_batch_from_cache( action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, empty_prompt=cache[""], multiplier=target.multiplier * -1.0, - weight=target.weight + weight=target.weight, + target=target ), ] return prompt_pair_batch + + +def build_latent_image_batch_for_prompt_pair( + pos_latent, + neg_latent, + prompt_pair: EncodedPromptPair, + prompt_chunk_size +): + erase_negative = len(prompt_pair.target.positive.strip()) == 0 + enhance_positive = len(prompt_pair.target.negative.strip()) == 0 + both = not erase_negative and not enhance_positive + + prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size) + if both and len(prompt_pair_chunks) != 4: + raise Exception("Invalid prompt pair chunks") + if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2: + raise Exception("Invalid prompt pair chunks") + + latent_list = [] + + if both or erase_negative: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(pos_latent) + if both or enhance_positive: + latent_list.append(neg_latent) + if both or erase_negative: + latent_list.append(neg_latent) + + return torch.cat(latent_list, dim=0)