mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added additional config options for custom plugins I needed
This commit is contained in:
@@ -802,7 +802,7 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.train_text_encoder:
|
if self.train_config.train_text_encoder:
|
||||||
grad_on_text_encoder = True
|
grad_on_text_encoder = True
|
||||||
|
|
||||||
if self.embedding:
|
if self.embedding is not None:
|
||||||
grad_on_text_encoder = True
|
grad_on_text_encoder = True
|
||||||
|
|
||||||
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
||||||
@@ -1095,13 +1095,14 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
else:
|
else:
|
||||||
with self.timer('predict_unet'):
|
with self.timer('predict_unet'):
|
||||||
if unconditional_embeds is not None:
|
if unconditional_embeds is not None:
|
||||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
|
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||||
noise_pred = self.sd.predict_noise(
|
noise_pred = self.sd.predict_noise(
|
||||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||||
unconditional_embeddings=unconditional_embeds,
|
unconditional_embeddings=unconditional_embeds,
|
||||||
timestep=timesteps,
|
timestep=timesteps,
|
||||||
guidance_scale=self.train_config.cfg_scale,
|
guidance_scale=self.train_config.cfg_scale,
|
||||||
|
detach_unconditional=False,
|
||||||
**pred_kwargs
|
**pred_kwargs
|
||||||
)
|
)
|
||||||
self.after_unet_predict()
|
self.after_unet_predict()
|
||||||
|
|||||||
@@ -333,6 +333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# remove all but the latest max_step_saves_to_keep
|
# remove all but the latest max_step_saves_to_keep
|
||||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||||
|
|
||||||
|
# remove duplicates
|
||||||
|
items_to_remove = list(dict.fromkeys(items_to_remove))
|
||||||
|
|
||||||
for item in items_to_remove:
|
for item in items_to_remove:
|
||||||
self.print(f"Removing old save: {item}")
|
self.print(f"Removing old save: {item}")
|
||||||
if os.path.isdir(item):
|
if os.path.isdir(item):
|
||||||
@@ -758,7 +761,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
do_double = False
|
do_double = False
|
||||||
|
|
||||||
with self.timer('prepare_noise'):
|
with self.timer('prepare_noise'):
|
||||||
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
|
num_train_timesteps = self.train_config.num_train_timesteps
|
||||||
|
|
||||||
if self.train_config.noise_scheduler in ['custom_lcm']:
|
if self.train_config.noise_scheduler in ['custom_lcm']:
|
||||||
# we store this value on our custom one
|
# we store this value on our custom one
|
||||||
@@ -791,14 +794,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
||||||
|
|
||||||
if content_or_style == 'content':
|
if content_or_style == 'content':
|
||||||
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
|
||||||
elif content_or_style == 'style':
|
elif content_or_style == 'style':
|
||||||
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
|
||||||
|
|
||||||
timestep_indices = value_map(
|
timestep_indices = value_map(
|
||||||
timestep_indices,
|
timestep_indices,
|
||||||
0,
|
0,
|
||||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
self.train_config.num_train_timesteps - 1,
|
||||||
min_noise_steps,
|
min_noise_steps,
|
||||||
max_noise_steps - 1
|
max_noise_steps - 1
|
||||||
)
|
)
|
||||||
@@ -1234,6 +1237,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# load last saved weights
|
# load last saved weights
|
||||||
if latest_save_path is not None:
|
if latest_save_path is not None:
|
||||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||||
|
if self.embedding.step > 1:
|
||||||
|
self.step_num = self.embedding.step
|
||||||
|
self.start_step = self.step_num
|
||||||
|
|
||||||
# self.step_num = self.embedding.step
|
# self.step_num = self.embedding.step
|
||||||
# self.start_step = self.step_num
|
# self.start_step = self.step_num
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class TrainConfig:
|
|||||||
self.negative_prompt = kwargs.get('negative_prompt', None)
|
self.negative_prompt = kwargs.get('negative_prompt', None)
|
||||||
# multiplier applied to loss on regularization images
|
# multiplier applied to loss on regularization images
|
||||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||||
|
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
||||||
|
|
||||||
# dropout that happens before encoding. It functions independently per text encoder
|
# dropout that happens before encoding. It functions independently per text encoder
|
||||||
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
||||||
|
|||||||
@@ -610,9 +610,6 @@ class ClipImageFileItemDTOMixin:
|
|||||||
if self.dataset_config.clip_image_shuffle_augmentations:
|
if self.dataset_config.clip_image_shuffle_augmentations:
|
||||||
self.build_clip_imag_augmentation_transform()
|
self.build_clip_imag_augmentation_transform()
|
||||||
|
|
||||||
# save the original tensor
|
|
||||||
self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
|
|
||||||
|
|
||||||
open_cv_image = np.array(img)
|
open_cv_image = np.array(img)
|
||||||
# Convert RGB to BGR
|
# Convert RGB to BGR
|
||||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||||
|
|||||||
@@ -413,6 +413,7 @@ def get_guided_loss_polarity(
|
|||||||
device = sd.device_torch
|
device = sd.device_torch
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
dtype = get_torch_dtype(dtype)
|
dtype = get_torch_dtype(dtype)
|
||||||
|
noise = noise.to(device, dtype=dtype).detach()
|
||||||
|
|
||||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
|||||||
use_bias: bool = False,
|
use_bias: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
self.can_merge_in = True
|
||||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||||
ToolkitModuleMixin.__init__(self, network=network)
|
ToolkitModuleMixin.__init__(self, network=network)
|
||||||
torch.nn.Module.__init__(self)
|
torch.nn.Module.__init__(self)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from toolkit import train_tools
|
|||||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
|
||||||
from toolkit.reference_adapter import ReferenceAdapter
|
from toolkit.reference_adapter import ReferenceAdapter
|
||||||
from toolkit.sampler import get_sampler
|
from toolkit.sampler import get_sampler
|
||||||
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
|
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
|
||||||
@@ -701,11 +701,6 @@ class StableDiffusion:
|
|||||||
# get scheduler class name
|
# get scheduler class name
|
||||||
scheduler_class_name = self.noise_scheduler.__class__.__name__
|
scheduler_class_name = self.noise_scheduler.__class__.__name__
|
||||||
|
|
||||||
index_noise_schedulers = [
|
|
||||||
'DPMSolverMultistepScheduler',
|
|
||||||
'EulerDiscreteSchedulerOutput',
|
|
||||||
]
|
|
||||||
|
|
||||||
# todo handle if timestep is single value
|
# todo handle if timestep is single value
|
||||||
|
|
||||||
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
|
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
|
||||||
@@ -770,6 +765,7 @@ class StableDiffusion:
|
|||||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||||
is_input_scaled=False,
|
is_input_scaled=False,
|
||||||
|
detach_unconditional=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -777,11 +773,10 @@ class StableDiffusion:
|
|||||||
if text_embeddings is None and conditional_embeddings is None:
|
if text_embeddings is None and conditional_embeddings is None:
|
||||||
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
|
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
|
||||||
if text_embeddings is None and unconditional_embeddings is not None:
|
if text_embeddings is None and unconditional_embeddings is not None:
|
||||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
text_embeddings = concat_prompt_embeds([
|
||||||
unconditional_embeddings, # negative embedding
|
unconditional_embeddings, # negative embedding
|
||||||
conditional_embeddings, # positive embedding
|
conditional_embeddings, # positive embedding
|
||||||
1, # batch size
|
])
|
||||||
)
|
|
||||||
elif text_embeddings is None and conditional_embeddings is not None:
|
elif text_embeddings is None and conditional_embeddings is not None:
|
||||||
# not doing cfg
|
# not doing cfg
|
||||||
text_embeddings = conditional_embeddings
|
text_embeddings = conditional_embeddings
|
||||||
@@ -932,7 +927,7 @@ class StableDiffusion:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
# if we are doing classifier free guidance, need to double up
|
# if we are doing classifier free guidance, need to double up
|
||||||
latent_model_input = torch.cat([latents] * 2)
|
latent_model_input = torch.cat([latents] * 2, dim=0)
|
||||||
timestep = torch.cat([timestep] * 2)
|
timestep = torch.cat([timestep] * 2)
|
||||||
else:
|
else:
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
@@ -946,7 +941,7 @@ class StableDiffusion:
|
|||||||
if ts_bs == 1:
|
if ts_bs == 1:
|
||||||
timestep = torch.cat([timestep] * latent_model_input.shape[0])
|
timestep = torch.cat([timestep] * latent_model_input.shape[0])
|
||||||
elif ts_bs * 2 == latent_model_input.shape[0]:
|
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||||
timestep = torch.cat([timestep] * 2)
|
timestep = torch.cat([timestep] * 2, dim=0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||||
@@ -961,7 +956,9 @@ class StableDiffusion:
|
|||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
# perform guidance
|
# perform guidance
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
|
||||||
|
if detach_unconditional:
|
||||||
|
noise_pred_uncond = noise_pred_uncond.detach()
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
noise_pred_text - noise_pred_uncond
|
noise_pred_text - noise_pred_uncond
|
||||||
)
|
)
|
||||||
@@ -973,13 +970,15 @@ class StableDiffusion:
|
|||||||
|
|
||||||
return noise_pred
|
return noise_pred
|
||||||
|
|
||||||
def step_scheduler(self, model_input, latent_input, timestep_tensor):
|
def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
|
||||||
|
if noise_scheduler is None:
|
||||||
|
noise_scheduler = self.noise_scheduler
|
||||||
# // sometimes they are on the wrong device, no idea why
|
# // sometimes they are on the wrong device, no idea why
|
||||||
if isinstance(self.noise_scheduler, DDPMScheduler) or isinstance(self.noise_scheduler, LCMScheduler):
|
if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
|
||||||
try:
|
try:
|
||||||
self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
|
||||||
self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
|
noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
|
||||||
self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
|
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -993,12 +992,12 @@ class StableDiffusion:
|
|||||||
|
|
||||||
for idx in range(model_input.shape[0]):
|
for idx in range(model_input.shape[0]):
|
||||||
# Reset it so it is unique for each chunk
|
# Reset it so it is unique for each chunk
|
||||||
if hasattr(self.noise_scheduler, '_step_index'):
|
if hasattr(noise_scheduler, '_step_index'):
|
||||||
self.noise_scheduler._step_index = None
|
noise_scheduler._step_index = None
|
||||||
if hasattr(self.noise_scheduler, 'is_scale_input_called'):
|
if hasattr(noise_scheduler, 'is_scale_input_called'):
|
||||||
self.noise_scheduler.is_scale_input_called = True
|
noise_scheduler.is_scale_input_called = True
|
||||||
out_chunks.append(
|
out_chunks.append(
|
||||||
self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
|
noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
|
||||||
0]
|
0]
|
||||||
)
|
)
|
||||||
return torch.cat(out_chunks, dim=0)
|
return torch.cat(out_chunks, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user