From dc8448d958037108165145c8ecdc41cd3a2658a7 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 6 Nov 2023 09:22:58 -0700 Subject: [PATCH] Added a way pass refiner ratio to sample config --- jobs/process/BaseSDTrainProcess.py | 1 + toolkit/config_modules.py | 5 +++++ toolkit/sd_device_states_presets.py | 3 +++ toolkit/stable_diffusion_model.py | 8 ++++---- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 30ce7ec6..863e4dca 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -220,6 +220,7 @@ class BaseSDTrainProcess(BaseTrainProcess): output_path=output_path, output_ext=sample_config.ext, adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + refiner_start_at=sample_config.refiner_start_at, **extra_args )) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index c3ca7c3b..b8f8122e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -45,6 +45,7 @@ class SampleConfig: self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) self.ext: ImgExt = kwargs.get('format', 'jpg') self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) + self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists class LormModuleSettingsConfig: @@ -430,6 +431,7 @@ class GenerateImageConfig: adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning latents: Union[torch.Tensor | None] = None, # input latent to start with, extra_kwargs: dict = None, # extra data to save with prompt file + refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end ): self.width: int = width self.height: int = height @@ -456,6 +458,7 @@ class GenerateImageConfig: self.adapter_image_path: str = adapter_image_path self.adapter_conditioning_scale: float = adapter_conditioning_scale self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {} + self.refiner_start_at = refiner_start_at # prompt string will override any settings above self._process_prompt_string() @@ -612,6 +615,8 @@ class GenerateImageConfig: self.guidance_rescale = float(content) elif flag == 'a': self.adapter_conditioning_scale = float(content) + elif flag == 'ref': + self.refiner_start_at = float(content) def post_process_embeddings( self, diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py index f55744d2..fb4663a0 100644 --- a/toolkit/sd_device_states_presets.py +++ b/toolkit/sd_device_states_presets.py @@ -69,6 +69,9 @@ def get_train_sd_device_state_preset( preset['refiner_unet']['training'] = True preset['refiner_unet']['requires_grad'] = True preset['refiner_unet']['device'] = device + # if not training unet, move that to cpu + if not train_unet: + preset['unet']['device'] = 'cpu' if train_lora: # preset['text_encoder']['requires_grad'] = False diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 8077f4c5..891230b6 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -472,9 +472,9 @@ class StableDiffusion: conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds) - if self.refiner_unet is not None: + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: # if we have a refiner loaded, set the denoising end at the refiner start - extra['denoising_end'] = self.model_config.refiner_start_at + extra['denoising_end'] = gen_config.refiner_start_at extra['output_type'] = 'latent' if not self.is_xl: raise ValueError("Refiner is only supported for XL models") @@ -526,7 +526,7 @@ class StableDiffusion: **extra ).images[0] - if refiner_pipeline is not None: + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: # slide off just the last 1280 on the last dim as refiner does not use first text encoder # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] @@ -546,7 +546,7 @@ class StableDiffusion: num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, guidance_rescale=grs, - denoising_start=self.model_config.refiner_start_at, + denoising_start=gen_config.refiner_start_at, denoising_end=gen_config.num_inference_steps, image=img.unsqueeze(0) ).images[0]