mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 07:49:24 +00:00
Added additional config options for custom plugins I needed
This commit is contained in:
@@ -26,7 +26,7 @@ from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
|
||||
@@ -701,11 +701,6 @@ class StableDiffusion:
|
||||
# get scheduler class name
|
||||
scheduler_class_name = self.noise_scheduler.__class__.__name__
|
||||
|
||||
index_noise_schedulers = [
|
||||
'DPMSolverMultistepScheduler',
|
||||
'EulerDiscreteSchedulerOutput',
|
||||
]
|
||||
|
||||
# todo handle if timestep is single value
|
||||
|
||||
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
|
||||
@@ -770,6 +765,7 @@ class StableDiffusion:
|
||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
is_input_scaled=False,
|
||||
detach_unconditional=False,
|
||||
**kwargs,
|
||||
):
|
||||
with torch.no_grad():
|
||||
@@ -777,11 +773,10 @@ class StableDiffusion:
|
||||
if text_embeddings is None and conditional_embeddings is None:
|
||||
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
|
||||
if text_embeddings is None and unconditional_embeddings is not None:
|
||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||
text_embeddings = concat_prompt_embeds([
|
||||
unconditional_embeddings, # negative embedding
|
||||
conditional_embeddings, # positive embedding
|
||||
1, # batch size
|
||||
)
|
||||
])
|
||||
elif text_embeddings is None and conditional_embeddings is not None:
|
||||
# not doing cfg
|
||||
text_embeddings = conditional_embeddings
|
||||
@@ -932,7 +927,7 @@ class StableDiffusion:
|
||||
with torch.no_grad():
|
||||
if do_classifier_free_guidance:
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = torch.cat([latents] * 2, dim=0)
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
@@ -946,7 +941,7 @@ class StableDiffusion:
|
||||
if ts_bs == 1:
|
||||
timestep = torch.cat([timestep] * latent_model_input.shape[0])
|
||||
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
timestep = torch.cat([timestep] * 2, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
@@ -961,7 +956,9 @@ class StableDiffusion:
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
|
||||
if detach_unconditional:
|
||||
noise_pred_uncond = noise_pred_uncond.detach()
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
@@ -973,13 +970,15 @@ class StableDiffusion:
|
||||
|
||||
return noise_pred
|
||||
|
||||
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
||||
def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
|
||||
if noise_scheduler is None:
|
||||
noise_scheduler = self.noise_scheduler
|
||||
# // sometimes they are on the wrong device, no idea why
|
||||
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
|
||||
if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
|
||||
try:
|
||||
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
||||
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||
noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
|
||||
noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
|
||||
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@@ -993,12 +992,12 @@ class StableDiffusion:
|
||||
|
||||
for idx in range(model_input.shape[0]):
|
||||
# Reset it so it is unique for the
|
||||
if hasattr(self.noise_scheduler, '_step_index'):
|
||||
self.noise_scheduler._step_index = None
|
||||
if hasattr(self.noise_scheduler, 'is_scale_input_called'):
|
||||
self.noise_scheduler.is_scale_input_called = True
|
||||
if hasattr(noise_scheduler, '_step_index'):
|
||||
noise_scheduler._step_index = None
|
||||
if hasattr(noise_scheduler, 'is_scale_input_called'):
|
||||
noise_scheduler.is_scale_input_called = True
|
||||
out_chunks.append(
|
||||
self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
|
||||
noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
|
||||
0]
|
||||
)
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user