diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 9a7c8302..7b0b14a1 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -114,7 +114,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 tokenizer=self.sd.tokenizer[0],
                 tokenizer_2=self.sd.tokenizer[1],
                 scheduler=self.sd.noise_scheduler,
-            )
+            ).to(self.device_torch)
         else:
             pipeline = StableDiffusionPipeline(
                 vae=self.sd.vae,
@@ -125,7 +125,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
-            )
+            ).to(self.device_torch)

         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
@@ -387,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             text_embeddings: PromptEmbeds,
             timestep: int,
             guidance_scale=7.5,
-            guidance_rescale=0, # 0.7
+            guidance_rescale=0,  # 0.7
             add_time_ids=None,
             **kwargs,
     ):
@@ -585,17 +585,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             unet.eval()

         if self.network_config is not None:
-            conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
             self.network = LoRASpecialNetwork(
                 text_encoder=text_encoder,
                 unet=unet,
                 lora_dim=self.network_config.linear,
                 multiplier=1.0,
-                alpha=self.network_config.alpha,
+                alpha=self.network_config.linear_alpha,
                 train_unet=self.train_config.train_unet,
                 train_text_encoder=self.train_config.train_text_encoder,
-                conv_lora_dim=conv,
-                conv_alpha=self.network_config.alpha if conv is not None else None,
+                conv_lora_dim=self.network_config.conv,
+                conv_alpha=self.network_config.conv_alpha,
             )

             self.network.force_to(self.device_torch, dtype=dtype)
diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py
index 51e4f3d2..a0ed1855 100644
--- a/jobs/process/TrainSDRescaleProcess.py
+++ b/jobs/process/TrainSDRescaleProcess.py
@@ -43,6 +43,7 @@ class RescaleConfig:
         self.prompt_file = kwargs.get('prompt_file', None)
         self.prompt_tensors = kwargs.get('prompt_tensors', None)
         self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
+        self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)

         if self.prompt_file is None:
             raise ValueError("prompt_file is required")
@@ -64,7 +65,7 @@ class PromptEmbedsCache:
 class TrainSDRescaleProcess(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         # pass our custom pipeline to super so it sets it up
-        super().__init__(process_id, job, config, custom_pipeline=TransferStableDiffusionXLPipeline)
+        super().__init__(process_id, job, config)
         self.step_num = 0
         self.start_step = 0
         self.device = self.get_conf('device', self.job.device)
@@ -158,31 +159,36 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
     def hook_train_loop(self):
         dtype = get_torch_dtype(self.train_config.dtype)

+        do_dropout = False
+
+        # see if we should dropout
+        if self.rescale_config.prompt_dropout > 0.0:
+            thresh = int(self.rescale_config.prompt_dropout * 100)
+            if torch.randint(0, 100, (1,)).item() < thresh:
+                do_dropout = True
+
         # get random encoded prompt from cache
-        prompt_txt = self.prompt_txt_list[
+        positive_prompt_txt = self.prompt_txt_list[
             torch.randint(0, len(self.prompt_txt_list), (1,)).item()
         ]
-        prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
-        prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
-        neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
-        neutral.text_embeds.to(device=self.device_torch, dtype=dtype)
-        if hasattr(prompt, 'pooled_embeds') \
-                and hasattr(neutral, 'pooled_embeds') \
-                and prompt.pooled_embeds is not None \
-                and neutral.pooled_embeds is not None:
-            prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
-            neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype)
+        negative_prompt_txt = self.prompt_txt_list[
+            torch.randint(0, len(self.prompt_txt_list), (1,)).item()
+        ]
+        if do_dropout:
+            positive_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
+            negative_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
+        else:
+            positive_prompt = self.prompt_cache[positive_prompt_txt].to(device=self.device_torch, dtype=dtype)
+            negative_prompt = self.prompt_cache[negative_prompt_txt].to(device=self.device_torch, dtype=dtype)

-        if prompt is None:
-            raise ValueError(f"Prompt {prompt_txt} is not in cache")
+        if positive_prompt is None:
+            raise ValueError(f"Prompt {positive_prompt_txt} is not in cache")
+        if negative_prompt is None:
+            raise ValueError(f"Prompt {negative_prompt_txt} is not in cache")

         loss_function = torch.nn.MSELoss()

         with torch.no_grad():
-            # self.sd.noise_scheduler.set_timesteps(
-            #     self.train_config.max_denoising_steps, device=self.device_torch
-            # )
-
             self.optimizer.zero_grad()

             # # ger a random number of steps
@@ -190,63 +196,89 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
                 1, self.train_config.max_denoising_steps, (1,)
             ).item()

+            # set the scheduler to the number of steps
+            self.sd.noise_scheduler.set_timesteps(
+                timesteps_to, device=self.device_torch
+            )
+
             # get noise
-            latents = self.get_latent_noise(
+            noise = self.get_latent_noise(
                 pixel_height=self.rescale_config.from_resolution,
                 pixel_width=self.rescale_config.from_resolution,
             ).to(self.device_torch, dtype=dtype)

-            self.sd.pipeline.to(self.device_torch)
             torch.set_default_device(self.device_torch)

-            # turn off progress bar
-            self.sd.pipeline.set_progress_bar_config(disable=True)
+            # get latents
+            latents = noise * self.sd.noise_scheduler.init_noise_sigma
+            latents = latents.to(self.device_torch, dtype=dtype)

-            # get random guidance scale from 1.0 to 10.0
+            # get random guidance scale from 1.0 to 10.0 (CFG)
             guidance_scale = torch.rand(1).item() * 9.0 + 1.0

             loss_arr = []

-
             max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
             # pad with spaces
             timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
             new_description = f"{self.job.name} ts: {timestep_str}"
             self.progress_bar.set_description(new_description)

-            def pre_condition_callback(target_pred, input_latents):
-                # handle any manipulations before feeding to our network
-                reduced_pred = self.reduce_size_fn(target_pred)
-                reduced_latents = self.reduce_size_fn(input_latents)
-                self.optimizer.zero_grad()
-                return reduced_pred, reduced_latents
+        # Begin gradient accumulation
+        self.optimizer.zero_grad()

-            def each_step_callback(noise_target, noise_train_pred):
-                noise_target.requires_grad = False
-                loss = loss_function(noise_target, noise_train_pred)
-                loss_arr.append(loss.item())
-                loss.backward()
-                self.optimizer.step()
-                self.lr_scheduler.step()
-                self.optimizer.zero_grad()
+        # perform the diffusion
+        for timestep in tqdm(self.sd.noise_scheduler.timesteps, leave=False):
+            assert not self.network.is_active

-            # run the pipeline
-            self.sd.pipeline.transfer_diffuse(
-                num_inference_steps=timesteps_to,
-                latents=latents,
-                prompt_embeds=prompt.text_embeds,
-                negative_prompt_embeds=neutral.text_embeds,
-                pooled_prompt_embeds=prompt.pooled_embeds,
-                negative_pooled_prompt_embeds=neutral.pooled_embeds,
-                output_type="latent",
-                num_images_per_prompt=self.train_config.batch_size,
-                guidance_scale=guidance_scale,
-                network=self.network,
-                target_unet=self.sd.unet,
-                pre_condition_callback=pre_condition_callback,
-                each_step_callback=each_step_callback,
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                negative_prompt,  # unconditional (negative prompt)
+                positive_prompt,  # conditional (positive prompt)
+                self.train_config.batch_size,
             )
+            with torch.no_grad():
+                noise_pred_target = self.predict_noise(
+                    latents,
+                    text_embeddings=text_embeddings,
+                    timestep=timestep,
+                    guidance_scale=guidance_scale
+                )
+
+            # todo should we do every step?
+            do_train_cycle = True
+
+            if do_train_cycle:
+                # get the reduced latents
+                with torch.no_grad():
+                    reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
+                    reduced_latents = self.reduce_size_fn(latents.detach())
+                with self.network:
+                    assert self.network.is_active
+                    self.network.multiplier = 1.0
+                    noise_pred_train = self.predict_noise(
+                        reduced_latents,
+                        text_embeddings=text_embeddings,
+                        timestep=timestep,
+                        guidance_scale=guidance_scale
+                    )
+
+                reduced_pred.requires_grad = False
+                loss = loss_function(noise_pred_train, reduced_pred)
+                loss_arr.append(loss.item())
+                loss.backward()
+                self.optimizer.step()
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
+
+            # get next latents
+            # todo allow to show latent here
+            latents = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
+
+        # reset prompt embeds
+        positive_prompt.to(device="cpu")
+        negative_prompt.to(device="cpu")
+
         flush()

         # reset network
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2f618512..3848461f 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -42,6 +42,8 @@ class NetworkConfig:
         self.linear: int = linear
         self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
+        self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
+        self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)


 class TrainConfig:
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 14d5cb07..5922d2cc 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -241,6 +241,9 @@ class LoRASpecialNetwork(LoRANetwork):

     @multiplier.setter
     def multiplier(self, value):
+        # only update if changed
+        if self._multiplier == value:
+            return
         self._multiplier = value
         self._update_lora_multiplier()

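Note (not part of the patch): a minimal sketch of how the new NetworkConfig alpha defaults resolve, assuming the kwargs arrive straight from the job config's "network" section. linear_alpha falls back to the existing alpha value, while conv_alpha falls back to conv (the conv rank), so conv layers get alpha == rank unless conv_alpha is set explicitly.

    # Hypothetical kwargs as NetworkConfig would receive them from a job config.
    kwargs = {'linear': 16, 'alpha': 8, 'conv': 4}

    alpha = kwargs.get('alpha', 1.0)
    linear_alpha = kwargs.get('linear_alpha', alpha)                 # no override -> 8 (falls back to alpha)
    conv_alpha = kwargs.get('conv_alpha', kwargs.get('conv', None))  # no override -> 4 (falls back to the conv rank)

    print(linear_alpha, conv_alpha)  # 8 4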
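Note (not part of the patch): a rough sketch of the prompt dropout added to hook_train_loop, under the assumption that RescaleConfig.prompt_dropout is a probability. With the default of 0.1, roughly 10% of training steps swap both the positive and negative prompts for the cached empty-string embedding.

    import torch

    prompt_dropout = 0.1  # default added by this patch

    # the patch compares an integer draw from [0, 100) against prompt_dropout * 100
    thresh = int(prompt_dropout * 100)                          # 10
    do_dropout = torch.randint(0, 100, (1,)).item() < thresh    # True on ~10% of steps

    # 'a photo of a cat' is a hypothetical cached prompt; '' selects the empty-prompt embedding
    prompt_key = '' if do_dropout else 'a photo of a cat'
    print(do_dropout, repr(prompt_key))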