From 596e59dd6d9549a4129f7440bc39003a08b0c1d7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 21 Jul 2023 22:06:49 -0600 Subject: [PATCH] Slider training functioning, time to perfect it --- jobs/process/TrainSliderProcess.py | 222 +++++++++++++++++------------ 1 file changed, 130 insertions(+), 92 deletions(-) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index eb3d008e..6e260940 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -3,6 +3,8 @@ import time from collections import OrderedDict import os +from typing import List + from toolkit.kohya_model_util import load_vae from toolkit.lora_special import LoRASpecialNetwork from toolkit.paths import REPOS_ROOT @@ -99,6 +101,21 @@ class ModelConfig: raise ValueError('name_or_path must be specified') +class SliderTargetConfig: + def __init__(self, **kwargs): + self.target_class: str = kwargs.get('target_class', None) + self.positive: str = kwargs.get('positive', None) + self.negative: str = kwargs.get('negative', None) + + +class SliderConfig: + def __init__(self, **kwargs): + targets = kwargs.get('targets', []) + targets = [SliderTargetConfig(**target) for target in targets] + self.targets: List[SliderTargetConfig] = targets + self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) + + class PromptSettingsOld: def __init__(self, **kwargs): self.target: str = kwargs.get('target', None) @@ -113,6 +130,24 @@ class PromptSettingsOld: self.dynamic_crops: bool = kwargs.get('dynamic_crops', False) # default is False. only used when model is XL +class EncodedPromptPair: + def __init__( + self, + target_class, + positive, + negative, + neutral, + width=512, + height=512 + ): + self.target_class = target_class + self.positive = positive + self.negative = negative + self.neutral = neutral + self.width = width + self.height = height + + class TrainSliderProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) @@ -127,10 +162,9 @@ class TrainSliderProcess(BaseTrainProcess): self.save_config = SaveConfig(**self.get_conf('save', {})) self.sample_config = SampleConfig(**self.get_conf('sample', {})) self.logging_config = LogingConfig(**self.get_conf('logging', {})) + self.slider_config = SliderConfig(**self.get_conf('slider', {})) self.sd = None - self.prompt_settings = self.get_prompt_settings() - # added later self.network = None self.scheduler = None @@ -142,14 +176,6 @@ class TrainSliderProcess(BaseTrainProcess): param.data = -param.data self.is_flipped = not self.is_flipped - def get_prompt_settings(self): - prompts = self.get_conf('prompts', required=True) - prompt_settings = [PromptSettingsOld(**prompt) for prompt in prompts] - # for i, prompt in enumerate(prompts): - # prompt_settings[i].fill_prompts(prompt) - - return prompt_settings - def sample(self, step=None): sample_folder = os.path.join(self.save_root, 'samples') if not os.path.exists(sample_folder): @@ -352,44 +378,38 @@ class TrainSliderProcess(BaseTrainProcess): max_iterations=self.train_config.steps, lr_min=self.train_config.lr / 100, # not sure why leco did this, but ill do it to ) - criteria = torch.nn.MSELoss() - - if self.logging_config.verbose: - print("Prompts") - for settings in self.prompt_settings: - print(settings) - - # debug - # debug_util.check_requires_grad(network) - # debug_util.check_training_mode(network) + loss_function = torch.nn.MSELoss() cache = PromptEmbedsCache() - prompt_pairs: list[PromptEmbedsPair] = [] + prompt_pairs: list[LatentPair] = [] + # get encoded latents for our prompts with torch.no_grad(): - for settings in self.prompt_settings: - self.print(settings) - for prompt in [ - settings.target, - settings.positive, - settings.neutral, - settings.unconditional, - ]: - if cache[prompt] == None: - cache[prompt] = train_util.encode_prompts( - tokenizer, text_encoder, [prompt] - ) + neutral = "" + for target in self.slider_config.targets: + for resolution in self.slider_config.resolutions: + width, height = resolution + for prompt in [ + target.target_class, + target.positive, + target.negative, + neutral # empty neutral + ]: + if cache[prompt] == None: + cache[prompt] = train_util.encode_prompts( + tokenizer, text_encoder, [prompt] + ) - prompt_pairs.append( - PromptEmbedsPair( - criteria, - cache[settings.target], - cache[settings.positive], - cache[settings.unconditional], - cache[settings.neutral], - settings, + prompt_pairs.append( + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + ) ) - ) # move to cpu to save vram # tokenizer.to("cpu") @@ -400,7 +420,6 @@ class TrainSliderProcess(BaseTrainProcess): self.print("Generating baseline samples before training") self.sample(0) - self.progress_bar = tqdm(range(self.train_config.steps)) self.progress_bar = tqdm( total=self.train_config.steps, desc=self.job.name, @@ -408,6 +427,29 @@ class TrainSliderProcess(BaseTrainProcess): ) self.step_num = 0 for step in range(self.train_config.steps): + + # get a random pair + prompt_pair: EncodedPromptPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + height = prompt_pair.height + width = prompt_pair.width + positive = prompt_pair.positive + target_class = prompt_pair.target_class + neutral = prompt_pair.neutral + negative = prompt_pair.negative + + # swap every other step and invert lora to spread slider + do_swap = step % 2 == 0 + + if do_swap: + negative = prompt_pair.positive + positive = prompt_pair.negative + # set the network in a negative weight + self.network.multiplier = -1.0 + + with torch.no_grad(): noise_scheduler.set_timesteps( self.train_config.max_denoising_steps, device=self.device_torch @@ -415,34 +457,17 @@ class TrainSliderProcess(BaseTrainProcess): optimizer.zero_grad() - prompt_pair: PromptEmbedsPair = prompt_pairs[ - torch.randint(0, len(prompt_pairs), (1,)).item() - ] - - # 1 ~ 49 random from 1 to 49 + # ger a random number of steps timesteps_to = torch.randint( 1, self.train_config.max_denoising_steps, (1,) ).item() - height, width = ( - prompt_pair.resolution, - prompt_pair.resolution, - ) - if prompt_pair.dynamic_resolution: - height, width = train_util.get_random_resolution_in_bucket( - prompt_pair.resolution - ) - - if self.logging_config.verbose: - self.print("guidance_scale:", prompt_pair.guidance_scale) - self.print("resolution:", prompt_pair.resolution) - self.print("dynamic_resolution:", prompt_pair.dynamic_resolution) - if prompt_pair.dynamic_resolution: - self.print("bucketed resolution:", (height, width)) - self.print("batch_size:", prompt_pair.batch_size) - latents = train_util.get_initial_latents( - noise_scheduler, prompt_pair.batch_size, height, width, 1 + noise_scheduler, + self.train_config.batch_size, + height, + width, + 1 ).to(self.device_torch, dtype=dtype) with self.network: @@ -453,9 +478,9 @@ class TrainSliderProcess(BaseTrainProcess): noise_scheduler, latents, # pass simple noise latents train_util.concat_embeddings( - prompt_pair.unconditional, - prompt_pair.target, - prompt_pair.batch_size, + positive, # unconditional + target_class, # target + self.train_config.batch_size, ), start_timesteps=0, total_timesteps=timesteps_to, @@ -468,16 +493,16 @@ class TrainSliderProcess(BaseTrainProcess): int(timesteps_to * 1000 / self.train_config.max_denoising_steps) ] - # with network: Only empty LoRA is enabled outside with network : - positive_latents = train_util.predict_noise( + # with network: 0 weight LoRA is enabled outside "with network:" + positive_latents = train_util.predict_noise( # positive_latents unet, noise_scheduler, current_timestep, denoised_latents, train_util.concat_embeddings( - prompt_pair.unconditional, - prompt_pair.positive, - prompt_pair.batch_size, + positive, # unconditional + negative, # positive + self.train_config.batch_size, ), guidance_scale=1, ).to("cpu", dtype=torch.float32) @@ -487,9 +512,9 @@ class TrainSliderProcess(BaseTrainProcess): current_timestep, denoised_latents, train_util.concat_embeddings( - prompt_pair.unconditional, - prompt_pair.neutral, - prompt_pair.batch_size, + positive, # unconditional + neutral, # neutral + self.train_config.batch_size, ), guidance_scale=1, ).to("cpu", dtype=torch.float32) @@ -499,16 +524,12 @@ class TrainSliderProcess(BaseTrainProcess): current_timestep, denoised_latents, train_util.concat_embeddings( - prompt_pair.unconditional, - prompt_pair.unconditional, - prompt_pair.batch_size, + positive, # unconditional + positive, # unconditional + self.train_config.batch_size, ), guidance_scale=1, ).to("cpu", dtype=torch.float32) - # if self.logging_config.verbose: - # self.print("positive_latents:", positive_latents[0, 0, :5, :5]) - # self.print("neutral_latents:", neutral_latents[0, 0, :5, :5]) - # self.print("unconditional_latents:", unconditional_latents[0, 0, :5, :5]) with self.network: target_latents = train_util.predict_noise( @@ -517,9 +538,9 @@ class TrainSliderProcess(BaseTrainProcess): current_timestep, denoised_latents, train_util.concat_embeddings( - prompt_pair.unconditional, - prompt_pair.target, - prompt_pair.batch_size, + positive, # unconditional + target_class, # target + self.train_config.batch_size, ), guidance_scale=1, ).to("cpu", dtype=torch.float32) @@ -531,12 +552,23 @@ class TrainSliderProcess(BaseTrainProcess): neutral_latents.requires_grad = False unconditional_latents.requires_grad = False - loss = prompt_pair.loss( - target_latents=target_latents, - positive_latents=positive_latents, - neutral_latents=neutral_latents, - unconditional_latents=unconditional_latents, + erase = True + guidance_scale = 1.0 + + offset = guidance_scale * (positive_latents - unconditional_latents) + + offset_neutral = neutral_latents + if erase: + offset_neutral -= offset + else: + # enhance + offset_neutral += offset + + loss = loss_function( + target_latents, + offset_neutral, ) + loss_float = loss.item() if self.train_config.optimizer.startswith('dadaptation'): learning_rate = ( @@ -561,6 +593,9 @@ class TrainSliderProcess(BaseTrainProcess): ) flush() + # reset network + self.network.multiplier = 1.0 + # don't do on first step if self.step_num != self.start_step: # pause progress bar @@ -594,8 +629,11 @@ class TrainSliderProcess(BaseTrainProcess): # end of step self.step_num = step + print("") + self.save() + del ( unet, noise_scheduler,