Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added additional config options for custom plugins I needed
@@ -802,7 +802,7 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.train_text_encoder:
             grad_on_text_encoder = True

-        if self.embedding:
+        if self.embedding is not None:
             grad_on_text_encoder = True

         if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
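A note on this hunk: `if self.embedding:` invokes the object's truthiness, which for container-like objects can be False even though the object exists, while `is not None` only tests presence. A minimal sketch of the difference (this Embedding stand-in is illustrative, not the toolkit's class):

    class Embedding:
        def __init__(self, vectors):
            self.vectors = vectors
        def __len__(self):
            # an object with __len__ is falsy when its length is 0
            return len(self.vectors)

    emb = Embedding(vectors=[])
    print(bool(emb))        # False: empty, so falsy
    print(emb is not None)  # True: it exists, which is what the trainer cares about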
@@ -1095,13 +1095,14 @@ class SDTrainer(BaseSDTrainProcess):
         else:
             with self.timer('predict_unet'):
                 if unconditional_embeds is not None:
-                    unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
+                    unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
                 noise_pred = self.sd.predict_noise(
                     latents=noisy_latents.to(self.device_torch, dtype=dtype),
                     conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                     unconditional_embeddings=unconditional_embeds,
                     timestep=timesteps,
                     guidance_scale=self.train_config.cfg_scale,
+                    detach_unconditional=False,
                     **pred_kwargs
                 )
             self.after_unet_predict()
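The two changes here work together: the unconditional (negative) embeddings are detached before the UNet call, and the new `detach_unconditional` flag is threaded through `predict_noise` (see the signature hunk further down). The effect of `.detach()` is to cut the autograd edge back to the text encoder for the negative branch only. A minimal sketch, with a plain tensor standing in for the prompt embeddings:

    import torch

    encoder_out = torch.randn(1, 77, 768, requires_grad=True)  # pretend text-encoder output
    attached = encoder_out.to(torch.float32)           # still part of the graph
    detached = encoder_out.to(torch.float32).detach()  # same values, no graph

    print(attached.requires_grad)  # True: gradients would reach the encoder
    print(detached.requires_grad)  # False: the negative branch trains nothing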
@@ -333,6 +333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # remove all but the latest max_step_saves_to_keep
             # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]

+            # remove duplicates
+            items_to_remove = list(dict.fromkeys(items_to_remove))
+
             for item in items_to_remove:
                 self.print(f"Removing old save: {item}")
                 if os.path.isdir(item):
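`dict.fromkeys` is the idiomatic order-preserving dedupe (dicts keep insertion order in Python 3.7+), so duplicate checkpoint paths are dropped without reshuffling which saves count as oldest. For example:

    paths = ['saves/step_100', 'saves/step_200', 'saves/step_100']
    print(list(dict.fromkeys(paths)))  # ['saves/step_100', 'saves/step_200']
    # list(set(paths)) would also dedupe, but with arbitrary order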
@@ -758,7 +761,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         do_double = False

         with self.timer('prepare_noise'):
-            num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
+            num_train_timesteps = self.train_config.num_train_timesteps

             if self.train_config.noise_scheduler in ['custom_lcm']:
                 # we store this value on our custom one
@@ -791,14 +794,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
             orig_timesteps = torch.rand((batch_size,), device=latents.device)

             if content_or_style == 'content':
-                timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
             elif content_or_style == 'style':
-                timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps

             timestep_indices = value_map(
                 timestep_indices,
                 0,
-                self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
+                self.train_config.num_train_timesteps - 1,
                 min_noise_steps,
                 max_noise_steps - 1
             )
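Both branches keep their cubic skew; only the upper bound moves from the scheduler's constant to the configurable `train_config.num_train_timesteps`. Cubing a uniform sample concentrates 'content' timesteps near 0 (low noise, fine detail), while `1 - t**3` concentrates 'style' timesteps near the top of the range (high noise, composition). `value_map` is a toolkit helper; the reimplementation below is an assumption about its semantics inferred from this call site, shown only to make the remapping concrete:

    import torch

    def value_map(x, min_in, max_in, min_out, max_out):
        # assumed: plain linear remap from [min_in, max_in] to [min_out, max_out]
        return (x - min_in) / (max_in - min_in) * (max_out - min_out) + min_out

    orig_timesteps = torch.rand(10000)
    content = orig_timesteps ** 3 * 1000      # mass piles up near 0
    style = (1 - orig_timesteps ** 3) * 1000  # mass piles up near 1000
    print(content.median().item(), style.median().item())  # roughly 125 vs 875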
@@ -1234,6 +1237,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # load last saved weights
         if latest_save_path is not None:
             self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
+            if self.embedding.step > 1:
+                self.step_num = self.embedding.step
+                self.start_step = self.step_num

             # self.step_num = self.embedding.step
             # self.start_step = self.step_num
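The guard means the step counters are only restored when the loaded embedding has actually trained; resuming from a freshly initialized embedding (step 0 or 1) leaves the trainer's own counters alone. A compact sketch of the behavior, with a stand-in object:

    class LoadedEmbedding:
        step = 0  # freshly created file, never trained

    embedding, step_num, start_step = LoadedEmbedding(), 0, 0
    if embedding.step > 1:  # skipped here, so nothing is overwritten
        step_num = embedding.step
        start_step = step_num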
@@ -222,6 +222,7 @@ class TrainConfig:
         self.negative_prompt = kwargs.get('negative_prompt', None)
         # multiplier applied to loss on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
+        self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)

         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
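Since `TrainConfig` pulls everything from `kwargs`, the new option rides along with the existing config plumbing; nothing else has to change for it to be settable. A minimal sketch, assuming the constructor accepts these as keyword arguments (which the `kwargs.get` calls imply) and that the class lives in `toolkit.config_modules` as the imports elsewhere suggest:

    from toolkit.config_modules import TrainConfig

    cfg = TrainConfig(num_train_timesteps=500)
    print(cfg.num_train_timesteps)  # 500 rather than the 1000 default
    print(cfg.reg_weight)           # 1.0, untouched defaults still apply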
@@ -610,9 +610,6 @@ class ClipImageFileItemDTOMixin:
         if self.dataset_config.clip_image_shuffle_augmentations:
             self.build_clip_imag_augmentation_transform()

-        # save the original tensor
-        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
-
         open_cv_image = np.array(img)
         # Convert RGB to BGR
         open_cv_image = open_cv_image[:, :, ::-1].copy()
@@ -413,6 +413,7 @@ def get_guided_loss_polarity(
     device = sd.device_torch
     with torch.no_grad():
         dtype = get_torch_dtype(dtype)
+        noise = noise.to(device, dtype=dtype).detach()

         conditional_latents = batch.latents.to(device, dtype=dtype).detach()
         unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
@@ -51,6 +51,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
             use_bias: bool = False,
             **kwargs
     ):
+        self.can_merge_in = True
         """if alpha == 0 or None, alpha is rank (no scaling)."""
         ToolkitModuleMixin.__init__(self, network=network)
         torch.nn.Module.__init__(self)
@@ -26,7 +26,7 @@ from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
-from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
+from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
@@ -701,11 +701,6 @@ class StableDiffusion:
         # get scheduler class name
         scheduler_class_name = self.noise_scheduler.__class__.__name__

-        index_noise_schedulers = [
-            'DPMSolverMultistepScheduler',
-            'EulerDiscreteSchedulerOutput',
-        ]
-
         # todo handle if timestep is single value

         original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
@@ -770,6 +765,7 @@ class StableDiffusion:
             conditional_embeddings: Union[PromptEmbeds, None] = None,
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
             is_input_scaled=False,
+            detach_unconditional=False,
             **kwargs,
     ):
         with torch.no_grad():
@@ -777,11 +773,10 @@ class StableDiffusion:
         if text_embeddings is None and conditional_embeddings is None:
             raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
         if text_embeddings is None and unconditional_embeddings is not None:
-            text_embeddings = train_tools.concat_prompt_embeddings(
+            text_embeddings = concat_prompt_embeds([
                 unconditional_embeddings,  # negative embedding
                 conditional_embeddings,  # positive embedding
-                1,  # batch size
-            )
+            ])
         elif text_embeddings is None and conditional_embeddings is not None:
             # not doing cfg
             text_embeddings = conditional_embeddings
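The replacement drops the old helper's explicit batch-size argument and keeps the diffusers CFG convention: negative embedding first, positive second, stacked along the batch dimension so both can be predicted in one UNet pass. What such a helper plausibly does, sketched on bare tensors (the real `PromptEmbeds` also carries pooled embeds, so this is a simplification):

    import torch

    def concat_embeds_sketch(embeds_list):
        # assumed core behavior: batch-concatenate each prompt embedding
        return torch.cat(embeds_list, dim=0)

    uncond = torch.randn(1, 77, 768)  # negative prompt
    cond = torch.randn(1, 77, 768)    # positive prompt
    both = concat_embeds_sketch([uncond, cond])
    print(both.shape)  # torch.Size([2, 77, 768]), matching the doubled latents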
@@ -932,7 +927,7 @@ class StableDiffusion:
         with torch.no_grad():
             if do_classifier_free_guidance:
                 # if we are doing classifier free guidance, need to double up
-                latent_model_input = torch.cat([latents] * 2)
+                latent_model_input = torch.cat([latents] * 2, dim=0)
                 timestep = torch.cat([timestep] * 2)
             else:
                 latent_model_input = latents
@@ -946,7 +941,7 @@ class StableDiffusion:
             if ts_bs == 1:
                 timestep = torch.cat([timestep] * latent_model_input.shape[0])
             elif ts_bs * 2 == latent_model_input.shape[0]:
-                timestep = torch.cat([timestep] * 2)
+                timestep = torch.cat([timestep] * 2, dim=0)
             else:
                 raise ValueError(
                     f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
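The `dim=0` additions in this hunk and the previous one are clarifying rather than behavior-changing: `torch.cat` concatenates along dim 0 by default, so the doubled CFG batch was already being stacked on the batch axis. Quick check:

    import torch

    latents = torch.randn(2, 4, 64, 64)
    assert torch.equal(torch.cat([latents] * 2), torch.cat([latents] * 2, dim=0))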
@@ -961,7 +956,9 @@ class StableDiffusion:

         if do_classifier_free_guidance:
             # perform guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
+            if detach_unconditional:
+                noise_pred_uncond = noise_pred_uncond.detach()
             noise_pred = noise_pred_uncond + guidance_scale * (
                 noise_pred_text - noise_pred_uncond
             )
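This is the other half of the `detach_unconditional` option: the standard CFG combination noise_pred = uncond + guidance_scale * (text - uncond), with the unconditional half optionally removed from the autograd graph so gradients only flow through the conditional prediction. A minimal sketch with illustrative shapes:

    import torch

    guidance_scale = 7.5
    noise_pred = torch.randn(4, 4, 64, 64, requires_grad=True)  # doubled CFG batch
    uncond, text = noise_pred.chunk(2, dim=0)

    uncond = uncond.detach()  # what detach_unconditional=True does
    guided = uncond + guidance_scale * (text - uncond)
    print(guided.requires_grad)  # True, but only via the conditional half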
@@ -973,13 +970,15 @@ class StableDiffusion:

         return noise_pred

-    def step_scheduler(self, model_input, latent_input, timestep_tensor):
+    def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
+        if noise_scheduler is None:
+            noise_scheduler = self.noise_scheduler
         # // sometimes they are on the wrong device, no idea why
-        if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
+        if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
             try:
-                self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
-                self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
-                self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+                noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
+                noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
+                noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
             except Exception as e:
                 pass
@@ -993,12 +992,12 @@ class StableDiffusion:

         for idx in range(model_input.shape[0]):
             # Reset it so it is unique for the
-            if hasattr(self.noise_scheduler, '_step_index'):
-                self.noise_scheduler._step_index = None
-            if hasattr(self.noise_scheduler, 'is_scale_input_called'):
-                self.noise_scheduler.is_scale_input_called = True
+            if hasattr(noise_scheduler, '_step_index'):
+                noise_scheduler._step_index = None
+            if hasattr(noise_scheduler, 'is_scale_input_called'):
+                noise_scheduler.is_scale_input_called = True
             out_chunks.append(
-                self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
+                noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
                     0]
             )
         return torch.cat(out_chunks, dim=0)
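Threading `noise_scheduler` through `step_scheduler` (defaulting to `self.noise_scheduler`) is what lets a custom plugin step against a scheduler it owns, without touching the one the trainer samples with. A plausible usage sketch, where `sd` is an assumed StableDiffusion instance and the input tensors are placeholders prepared elsewhere:

    from diffusers import DDPMScheduler

    # default path: identical to the old behavior
    out = sd.step_scheduler(model_input, latent_input, timestep_tensor)

    # plugin path: step against a scheduler the plugin configured itself
    plugin_scheduler = DDPMScheduler(num_train_timesteps=1000)
    plugin_scheduler.set_timesteps(50, device=sd.device_torch)
    out = sd.step_scheduler(model_input, latent_input, timestep_tensor,
                            noise_scheduler=plugin_scheduler)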