mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added reference adapters, many bug fixes, and more IP adapter work and customizability
This commit is contained in:
@@ -15,6 +15,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
|
||||
from toolkit.reference_adapter import ReferenceAdapter
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||
apply_learnable_snr_gos, LearnableSNRGamma
|
||||
@@ -285,9 +286,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if torch.isnan(prior_loss).any():
|
||||
raise ValueError("Prior loss is nan")
|
||||
|
||||
# prior_loss = prior_loss.mean([1, 2, 3])
|
||||
loss = loss + prior_loss
|
||||
prior_loss = prior_loss.mean([1, 2, 3])
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
if prior_loss is not None:
|
||||
loss = loss + prior_loss
|
||||
|
||||
if not self.train_config.train_turbo:
|
||||
if self.train_config.learnable_snr_gos:
|
||||
@@ -623,11 +626,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.network is not None:
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
is_ip_adapter = False
|
||||
was_ip_adapter_active = False
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
|
||||
is_ip_adapter = True
|
||||
was_ip_adapter_active = self.adapter.is_active
|
||||
can_disable_adapter = False
|
||||
was_adapter_active = False
|
||||
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
|
||||
can_disable_adapter = True
|
||||
was_adapter_active = self.adapter.is_active
|
||||
self.adapter.is_active = False
|
||||
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
@@ -666,8 +669,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
|
||||
del pred_kwargs['down_intrablock_additional_residuals']
|
||||
|
||||
if is_ip_adapter:
|
||||
self.adapter.is_active = was_ip_adapter_active
|
||||
if can_disable_adapter:
|
||||
self.adapter.is_active = was_adapter_active
|
||||
# restore network
|
||||
# self.network.multiplier = network_weight_list
|
||||
if self.network is not None:
|
||||
@@ -950,12 +953,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter_embeds'):
|
||||
if has_clip_image:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True
|
||||
)
|
||||
elif is_reg:
|
||||
if is_reg:
|
||||
# we will zero it out in the img embedder
|
||||
clip_images = torch.zeros(
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
@@ -967,6 +965,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
drop=True,
|
||||
is_training=True
|
||||
)
|
||||
elif has_clip_image:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
|
||||
|
||||
@@ -978,12 +981,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with self.timer('encode_adapter'):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass in our scheduler
|
||||
self.adapter.noise_scheduler = self.lr_scheduler
|
||||
if has_clip_image or has_adapter_img:
|
||||
img_to_use = clip_images if has_clip_image else adapter_images
|
||||
# currently 0-1 needs to be -1 to 1
|
||||
reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
|
||||
self.adapter.set_reference_images(reference_images)
|
||||
self.adapter.noise_scheduler = self.sd.noise_scheduler
|
||||
elif is_reg:
|
||||
self.adapter.set_blank_reference_images(noisy_latents.shape[0])
|
||||
else:
|
||||
self.adapter.set_reference_images(None)
|
||||
|
||||
prior_pred = None
|
||||
|
||||
do_reg_prior = False
|
||||
if is_reg and (self.network is not None or self.adapter is not None):
|
||||
# we are doing a reg image and we have a network or adapter
|
||||
do_reg_prior = True
|
||||
# if is_reg and (self.network is not None or self.adapter is not None):
|
||||
# # we are doing a reg image and we have a network or adapter
|
||||
# do_reg_prior = True
|
||||
|
||||
do_inverted_masked_prior = False
|
||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||
|
||||
Reference in New Issue
Block a user