mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-08 22:49:58 +00:00
Added a way to pass refiner ratio to sample config
This commit is contained in:
@@ -220,6 +220,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
output_path=output_path,
|
||||
output_ext=sample_config.ext,
|
||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||
refiner_start_at=sample_config.refiner_start_at,
|
||||
**extra_args
|
||||
))
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class SampleConfig:
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
|
||||
|
||||
|
||||
class LormModuleSettingsConfig:
|
||||
@@ -430,6 +431,7 @@ class GenerateImageConfig:
|
||||
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
||||
latents: Union[torch.Tensor | None] = None, # input latent to start with,
|
||||
extra_kwargs: dict = None, # extra data to save with prompt file
|
||||
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -456,6 +458,7 @@ class GenerateImageConfig:
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
||||
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
|
||||
self.refiner_start_at = refiner_start_at
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
@@ -612,6 +615,8 @@ class GenerateImageConfig:
|
||||
self.guidance_rescale = float(content)
|
||||
elif flag == 'a':
|
||||
self.adapter_conditioning_scale = float(content)
|
||||
elif flag == 'ref':
|
||||
self.refiner_start_at = float(content)
|
||||
|
||||
def post_process_embeddings(
|
||||
self,
|
||||
|
||||
@@ -69,6 +69,9 @@ def get_train_sd_device_state_preset(
|
||||
preset['refiner_unet']['training'] = True
|
||||
preset['refiner_unet']['requires_grad'] = True
|
||||
preset['refiner_unet']['device'] = device
|
||||
# if not training unet, move that to cpu
|
||||
if not train_unet:
|
||||
preset['unet']['device'] = 'cpu'
|
||||
|
||||
if train_lora:
|
||||
# preset['text_encoder']['requires_grad'] = False
|
||||
|
||||
@@ -472,9 +472,9 @@ class StableDiffusion:
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||
# if we have a refiner loaded, set the denoising end at the refiner start
|
||||
extra['denoising_end'] = self.model_config.refiner_start_at
|
||||
extra['denoising_end'] = gen_config.refiner_start_at
|
||||
extra['output_type'] = 'latent'
|
||||
if not self.is_xl:
|
||||
raise ValueError("Refiner is only supported for XL models")
|
||||
@@ -526,7 +526,7 @@ class StableDiffusion:
|
||||
**extra
|
||||
).images[0]
|
||||
|
||||
if refiner_pipeline is not None:
|
||||
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||
# slice off just the last 1280 on the last dim, as the refiner does not use the first text encoder
|
||||
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
|
||||
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
|
||||
@@ -546,7 +546,7 @@ class StableDiffusion:
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=grs,
|
||||
denoising_start=self.model_config.refiner_start_at,
|
||||
denoising_start=gen_config.refiner_start_at,
|
||||
denoising_end=gen_config.num_inference_steps,
|
||||
image=img.unsqueeze(0)
|
||||
).images[0]
|
||||
|
||||
Reference in New Issue
Block a user