Added a way to pass refiner ratio to sample config

Jaret Burkett
2023-11-06 09:22:58 -07:00
parent a8b3b8b8da
commit dc8448d958
4 changed files with 13 additions and 4 deletions

View File

@@ -220,6 +220,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
output_path=output_path,
output_ext=sample_config.ext,
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
+ refiner_start_at=sample_config.refiner_start_at,
**extra_args
))

View File

@@ -45,6 +45,7 @@ class SampleConfig:
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
+ self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)  # fraction of sampling at which to switch to the refiner, if one is loaded
class LormModuleSettingsConfig:
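
For illustration, a minimal sketch of how the new key is read, assuming SampleConfig's constructor simply accepts keyword arguments as the hunk above suggests (the 0.4 value and the call itself are hypothetical; only keys visible in this commit are used):

sample_config = SampleConfig(
    format='jpg',                    # stored as self.ext
    guidance_rescale=0.0,
    adapter_conditioning_scale=1.0,
    refiner_start_at=0.4,            # hand the sample off to the refiner at 40% of denoising
)
assert sample_config.refiner_start_at == 0.4   # defaults to 0.5 when omitted
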
@@ -430,6 +431,7 @@ class GenerateImageConfig:
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
latents: Union[torch.Tensor | None] = None, # input latent to start with,
extra_kwargs: dict = None, # extra data to save with prompt file
+ refiner_start_at: float = 0.5,  # start the refiner at this fraction of the denoising schedule, 0.0 to 1.0; 1.0 means the very end
):
self.width: int = width
self.height: int = height
@@ -456,6 +458,7 @@ class GenerateImageConfig:
self.adapter_image_path: str = adapter_image_path
self.adapter_conditioning_scale: float = adapter_conditioning_scale
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
+ self.refiner_start_at = refiner_start_at
# prompt string will override any settings above
self._process_prompt_string()
@@ -612,6 +615,8 @@ class GenerateImageConfig:
self.guidance_rescale = float(content)
elif flag == 'a':
self.adapter_conditioning_scale = float(content)
+ elif flag == 'ref':
+ self.refiner_start_at = float(content)
def post_process_embeddings(
self,
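
The same value can also be set per prompt through the flag parser extended above. A sketch, assuming the flags are written as "--name value" in the prompt string and that GenerateImageConfig takes the prompt as a keyword argument (neither is shown in this hunk):

# '--ref 0.4' would be split into flag 'ref' with content '0.4'
cfg = GenerateImageConfig(prompt="photo of a red fox in the snow --ref 0.4")
# _process_prompt_string() runs in __init__, so cfg.refiner_start_at is now 0.4 instead of the 0.5 default
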

View File

@@ -69,6 +69,9 @@ def get_train_sd_device_state_preset(
preset['refiner_unet']['training'] = True
preset['refiner_unet']['requires_grad'] = True
preset['refiner_unet']['device'] = device
+ # if not training unet, move that to cpu
+ if not train_unet:
+ preset['unet']['device'] = 'cpu'
if train_lora:
# preset['text_encoder']['requires_grad'] = False
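
For context, the preset modified above is a nested dict keyed by model part. A sketch of the state it ends up in when the refiner is being trained but the base unet is not; the overall shape and the frozen-unet fields are assumed, only the lines in the hunk come from this commit:

preset = {
    # base unet is not trained, so it is parked on the CPU
    'unet':         {'training': False, 'requires_grad': False, 'device': 'cpu'},
    'refiner_unet': {'training': True,  'requires_grad': True,  'device': device},
    # other parts (text encoders, vae, ...) omitted
}
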

View File

@@ -472,9 +472,9 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
- if self.refiner_unet is not None:
+ if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
# if we have a refiner loaded, set the denoising end at the refiner start
- extra['denoising_end'] = self.model_config.refiner_start_at
+ extra['denoising_end'] = gen_config.refiner_start_at
extra['output_type'] = 'latent'
if not self.is_xl:
raise ValueError("Refiner is only supported for XL models")
@@ -526,7 +526,7 @@ class StableDiffusion:
**extra
).images[0]
- if refiner_pipeline is not None:
+ if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
@@ -546,7 +546,7 @@ class StableDiffusion:
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=grs,
- denoising_start=self.model_config.refiner_start_at,
+ denoising_start=gen_config.refiner_start_at,
denoising_end=gen_config.num_inference_steps,
image=img.unsqueeze(0)
).images[0]
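
The reworked condition mirrors the standard diffusers SDXL base-plus-refiner handoff: the same fraction is passed as denoising_end to the base pass (which returns latents) and as denoising_start to the refiner pass. A standalone sketch of that pattern, independent of this repo's own pipeline wiring (model ids, prompt, and the 0.4 value are illustrative):

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2, vae=base.vae, torch_dtype=torch.float16,
).to("cuda")

refiner_start_at = 0.4  # same role as gen_config.refiner_start_at above
# base handles the first 40% of the schedule and hands latents to the refiner
latents = base(prompt="photo of a red fox", num_inference_steps=30,
               denoising_end=refiner_start_at, output_type="latent").images
# refiner picks up at the same fraction and finishes the image
image = refiner(prompt="photo of a red fox", num_inference_steps=30,
                denoising_start=refiner_start_at, image=latents).images[0]
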