mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added a way to pass refiner ratio to sample config
This commit is contained in:
@@ -220,6 +220,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
output_ext=sample_config.ext,
|
output_ext=sample_config.ext,
|
||||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||||
|
refiner_start_at=sample_config.refiner_start_at,
|
||||||
**extra_args
|
**extra_args
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ class SampleConfig:
|
|||||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||||
|
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
|
||||||
|
|
||||||
|
|
||||||
class LormModuleSettingsConfig:
|
class LormModuleSettingsConfig:
|
||||||
@@ -430,6 +431,7 @@ class GenerateImageConfig:
|
|||||||
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
||||||
latents: Union[torch.Tensor | None] = None, # input latent to start with,
|
latents: Union[torch.Tensor | None] = None, # input latent to start with,
|
||||||
extra_kwargs: dict = None, # extra data to save with prompt file
|
extra_kwargs: dict = None, # extra data to save with prompt file
|
||||||
|
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||||
):
|
):
|
||||||
self.width: int = width
|
self.width: int = width
|
||||||
self.height: int = height
|
self.height: int = height
|
||||||
@@ -456,6 +458,7 @@ class GenerateImageConfig:
|
|||||||
self.adapter_image_path: str = adapter_image_path
|
self.adapter_image_path: str = adapter_image_path
|
||||||
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
||||||
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
|
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
|
||||||
|
self.refiner_start_at = refiner_start_at
|
||||||
|
|
||||||
# prompt string will override any settings above
|
# prompt string will override any settings above
|
||||||
self._process_prompt_string()
|
self._process_prompt_string()
|
||||||
@@ -612,6 +615,8 @@ class GenerateImageConfig:
|
|||||||
self.guidance_rescale = float(content)
|
self.guidance_rescale = float(content)
|
||||||
elif flag == 'a':
|
elif flag == 'a':
|
||||||
self.adapter_conditioning_scale = float(content)
|
self.adapter_conditioning_scale = float(content)
|
||||||
|
elif flag == 'ref':
|
||||||
|
self.refiner_start_at = float(content)
|
||||||
|
|
||||||
def post_process_embeddings(
|
def post_process_embeddings(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -69,6 +69,9 @@ def get_train_sd_device_state_preset(
|
|||||||
preset['refiner_unet']['training'] = True
|
preset['refiner_unet']['training'] = True
|
||||||
preset['refiner_unet']['requires_grad'] = True
|
preset['refiner_unet']['requires_grad'] = True
|
||||||
preset['refiner_unet']['device'] = device
|
preset['refiner_unet']['device'] = device
|
||||||
|
# if not training unet, move that to cpu
|
||||||
|
if not train_unet:
|
||||||
|
preset['unet']['device'] = 'cpu'
|
||||||
|
|
||||||
if train_lora:
|
if train_lora:
|
||||||
# preset['text_encoder']['requires_grad'] = False
|
# preset['text_encoder']['requires_grad'] = False
|
||||||
|
|||||||
@@ -472,9 +472,9 @@ class StableDiffusion:
|
|||||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||||
|
|
||||||
if self.refiner_unet is not None:
|
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||||
# if we have a refiner loaded, set the denoising end at the refiner start
|
# if we have a refiner loaded, set the denoising end at the refiner start
|
||||||
extra['denoising_end'] = self.model_config.refiner_start_at
|
extra['denoising_end'] = gen_config.refiner_start_at
|
||||||
extra['output_type'] = 'latent'
|
extra['output_type'] = 'latent'
|
||||||
if not self.is_xl:
|
if not self.is_xl:
|
||||||
raise ValueError("Refiner is only supported for XL models")
|
raise ValueError("Refiner is only supported for XL models")
|
||||||
@@ -526,7 +526,7 @@ class StableDiffusion:
|
|||||||
**extra
|
**extra
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
if refiner_pipeline is not None:
|
if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
|
||||||
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
|
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
|
||||||
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
|
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
|
||||||
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
|
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
|
||||||
@@ -546,7 +546,7 @@ class StableDiffusion:
|
|||||||
num_inference_steps=gen_config.num_inference_steps,
|
num_inference_steps=gen_config.num_inference_steps,
|
||||||
guidance_scale=gen_config.guidance_scale,
|
guidance_scale=gen_config.guidance_scale,
|
||||||
guidance_rescale=grs,
|
guidance_rescale=grs,
|
||||||
denoising_start=self.model_config.refiner_start_at,
|
denoising_start=gen_config.refiner_start_at,
|
||||||
denoising_end=gen_config.num_inference_steps,
|
denoising_end=gen_config.num_inference_steps,
|
||||||
image=img.unsqueeze(0)
|
image=img.unsqueeze(0)
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user