Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
Added the ability to do CFG (classifier-free guidance) during training. Various bug fixes.
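In short, this commit threads an optional `unconditional_embeds: Optional[PromptEmbeds]` argument from `SDTrainer` through `get_guidance_loss`/`get_direct_guidance_loss` and into `predict_noise`, adds a `do_cfg` flag to `TrainConfig` (off by default), and makes the SAFE image encoder's token count configurable via `AdapterConfig.safe_tokens`. For orientation, a minimal sketch of the classifier-free guidance combination that a `predict_noise` implementation typically applies when unconditional embeddings are supplied; the helper below is illustrative, not code from this commit:

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # uncond + s * (cond - uncond); at s == 1.0 this reduces to pred_cond,
    # which is why the training paths below can pass guidance_scale=1.0.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

cond, uncond = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)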
@@ -319,6 +319,7 @@ class SDTrainer(BaseSDTrainProcess):
             pred_kwargs: dict,
             batch: 'DataLoaderBatchDTO',
             noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
             **kwargs
     ):
         loss = get_guidance_loss(

@@ -331,6 +332,7 @@ class SDTrainer(BaseSDTrainProcess):
             batch=batch,
             noise=noise,
             sd=self.sd,
+            unconditional_embeds=unconditional_embeds,
             **kwargs
         )

@@ -618,6 +620,7 @@ class SDTrainer(BaseSDTrainProcess):
             pred_kwargs: dict,
             batch: 'DataLoaderBatchDTO',
             noise: torch.Tensor,
+            unconditional_embeds: Optional[PromptEmbeds] = None,
             **kwargs
     ):
         # todo for embeddings, we need to run without trigger words

@@ -655,9 +658,13 @@ class SDTrainer(BaseSDTrainProcess):
                 # self.network.multiplier = 0.0
                 self.sd.unet.eval()

+                if unconditional_embeds is not None:
+                    unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
+
                 prior_pred = self.sd.predict_noise(
                     latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                     conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
+                    unconditional_embeddings=unconditional_embeds,
                     timestep=timesteps,
                     guidance_scale=1.0,
                     **pred_kwargs # adapter residuals in here
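The prior prediction above is a teacher-style pass: the UNet is put in eval mode and every input, including the new unconditional embeddings, is detached so the result serves purely as a target. A minimal sketch of that pattern with a stand-in module (the toolkit may rely on detached inputs rather than an explicit no_grad block; the effect on the target is the same):

import torch

net = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

net.eval()                       # mirrors self.sd.unet.eval()
with torch.no_grad():
    prior = net(x.detach())      # detached inputs, no autograd history
assert not prior.requires_grad   # safe to use as a loss target
net.train()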
@@ -901,6 +908,7 @@ class SDTrainer(BaseSDTrainProcess):
                     self.adapter(conditional_clip_embeds)

         with self.timer('encode_prompt'):
+            unconditional_embeds = None
             if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
                     conditional_embeds = self.sd.encode_prompt(

@@ -909,6 +917,15 @@ class SDTrainer(BaseSDTrainProcess):
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+
+                    if self.train_config.do_cfg:
+                        # todo only do one and repeat it
+                        unconditional_embeds = self.sd.encode_prompt(
+                            ["" for _ in range(noisy_latents.shape[0])],
+                            dropout_prob=self.train_config.prompt_dropout_prob,
+                            long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype)
             else:
                 with torch.set_grad_enabled(False):
                     # make sure it is in eval mode
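With `do_cfg` enabled, unconditional embeddings are produced by encoding one empty prompt per batch item; the `# todo only do one and repeat it` comment marks the obvious optimization of encoding a single empty prompt and expanding it. A minimal sketch of that optimization, with a plain tensor standing in for `PromptEmbeds` (shape is illustrative):

import torch

def repeat_uncond(single_embed: torch.Tensor, batch_size: int) -> torch.Tensor:
    # single_embed: (1, seq_len, dim), e.g. from encoding [""] once.
    # expand() creates a view, so the batch shares one copy of the data.
    return single_embed.expand(batch_size, -1, -1)

one = torch.randn(1, 77, 768)   # one empty-prompt embedding
batch = repeat_uncond(one, 4)   # reused across a batch of 4
assert batch.shape == (4, 77, 768)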
@@ -923,9 +940,19 @@ class SDTrainer(BaseSDTrainProcess):
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+                    if self.train_config.do_cfg:
+                        # todo only do one and repeat it
+                        unconditional_embeds = self.sd.encode_prompt(
+                            ["" for _ in range(noisy_latents.shape[0])],
+                            dropout_prob=self.train_config.prompt_dropout_prob,
+                            long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype)

             # detach the embeddings
             conditional_embeds = conditional_embeds.detach()
+            if self.train_config.do_cfg:
+                unconditional_embeds = unconditional_embeds.detach()

         # flush()
         pred_kwargs = {}
@@ -965,21 +992,43 @@ class SDTrainer(BaseSDTrainProcess):
                     drop=True,
                     is_training=True
                 )
+                if self.train_config.do_cfg:
+                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                        torch.zeros(
+                            (noisy_latents.shape[0], 3, 512, 512),
+                            device=self.device_torch, dtype=dtype
+                        ).detach(),
+                        is_training=True,
+                        drop=True
+                    )
             elif has_clip_image:
                 conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                     clip_images.detach().to(self.device_torch, dtype=dtype),
                     is_training=True
                 )
+                if self.train_config.do_cfg:
+                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                        torch.zeros(
+                            (noisy_latents.shape[0], 3, 512, 512),
+                            device=self.device_torch, dtype=dtype
+                        ).detach(),
+                        is_training=True,
+                        drop=True
+                    )
             else:
                 raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")

             if not self.adapter_config.train_image_encoder:
                 # we are not training the image encoder, so we need to detach the embeds
                 conditional_clip_embeds = conditional_clip_embeds.detach()
+                if self.train_config.do_cfg:
+                    unconditional_clip_embeds = unconditional_clip_embeds.detach()


             with self.timer('encode_adapter'):
                 conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
+                if self.train_config.do_cfg:
+                    unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)

         if self.adapter and isinstance(self.adapter, ReferenceAdapter):
             # pass in our scheduler

@@ -1017,6 +1066,7 @@ class SDTrainer(BaseSDTrainProcess):
                 pred_kwargs=pred_kwargs,
                 noise=noise,
                 batch=batch,
+                unconditional_embeds=unconditional_embeds
             )

         self.before_unet_predict()

@@ -1032,13 +1082,17 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs=pred_kwargs,
                     batch=batch,
                     noise=noise,
+                    unconditional_embeds=unconditional_embeds
                 )

             else:
                 with self.timer('predict_unet'):
+                    if unconditional_embeds is not None:
+                        unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
                     noise_pred = self.sd.predict_noise(
                         latents=noisy_latents.to(self.device_torch, dtype=dtype),
                         conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        unconditional_embeddings=unconditional_embeds,
                         timestep=timesteps,
                         guidance_scale=1.0,
                         **pred_kwargs
@@ -156,6 +156,7 @@ class AdapterConfig:
         self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
         self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
         self.safe_channels: int = kwargs.get('safe_channels', 2048)
+        self.safe_tokens: int = kwargs.get('safe_tokens', 8)

         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')

@@ -270,6 +271,7 @@ class TrainConfig:
             raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")

         self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
+        self.do_cfg = kwargs.get('do_cfg', False)


 class ModelConfig:
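Both new options arrive via `kwargs.get` with safe defaults, so existing configs behave exactly as before. A hedged usage sketch of the defaults this diff introduces, mimicking its `kwargs.get` pattern:

# e.g. parsed from a training config file; both keys are optional
kwargs = {'do_cfg': True, 'safe_tokens': 16}
do_cfg = kwargs.get('do_cfg', False)          # TrainConfig: defaults to False
safe_tokens = kwargs.get('safe_tokens', 8)    # AdapterConfig: defaults to 8
print(do_cfg, safe_tokens)                    # True 16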
@@ -1,5 +1,5 @@
 import torch
-from typing import Literal
+from typing import Literal, Optional

 from toolkit.basic import value_map
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

@@ -193,6 +193,7 @@ def get_direct_guidance_loss(
         batch: 'DataLoaderBatchDTO',
         noise: torch.Tensor,
         sd: 'StableDiffusion',
+        unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
 ):
     with torch.no_grad():

@@ -222,9 +223,14 @@ def get_direct_guidance_loss(

     # sd.network.multiplier = network_weight_list
     # do our prediction with LoRA active on the scaled guidance latents
+    if unconditional_embeds is not None:
+        unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
+        unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
+
     prediction = sd.predict_noise(
         latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
         conditional_embeddings=concat_prompt_embeds([conditional_embeds,conditional_embeds]).to(device, dtype=dtype).detach(),
+        unconditional_embeddings=unconditional_embeds,
         timestep=torch.cat([timesteps, timesteps]),
         guidance_scale=1.0,
         **pred_kwargs # adapter residuals in here

@@ -482,12 +488,14 @@ def get_guidance_loss(
         batch: 'DataLoaderBatchDTO',
         noise: torch.Tensor,
         sd: 'StableDiffusion',
+        unconditional_embeds: Optional[PromptEmbeds] = None,
         **kwargs
 ):
     # TODO add others and process individual batch items separately
     guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type

     if guidance_type == "targeted":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
         return get_targeted_guidance_loss(
             noisy_latents,
             conditional_embeds,

@@ -501,6 +509,7 @@ def get_guidance_loss(
             **kwargs
         )
     elif guidance_type == "polarity":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
         return get_guided_loss_polarity(
             noisy_latents,
             conditional_embeds,

@@ -515,6 +524,7 @@ def get_guidance_loss(
         )

     elif guidance_type == "targeted_polarity":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
         return get_targeted_polarity_loss(
             noisy_latents,
             conditional_embeds,

@@ -538,6 +548,7 @@ def get_guidance_loss(
             batch,
             noise,
             sd,
+            unconditional_embeds=unconditional_embeds,
             **kwargs
         )
     else:
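Taken together, the three asserts above define the dispatch contract in `get_guidance_loss`: only the fall-through path, which reaches `get_direct_guidance_loss`, forwards unconditional embeddings, while the targeted, polarity, and targeted_polarity paths reject them. A condensed sketch of that contract:

def check_unconditional(guidance_type: str, unconditional_embeds) -> None:
    # Mirrors the asserts added in this commit: only the direct path
    # accepts unconditional embeddings.
    if guidance_type in ('targeted', 'polarity', 'targeted_polarity'):
        assert unconditional_embeds is None, \
            f'Unconditional embeds are not supported for {guidance_type} guidance'

check_unconditional('targeted', None)  # passes
try:
    check_unconditional('polarity', object())
except AssertionError as err:
    print(err)  # Unconditional embeds are not supported for polarity guidance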
@@ -184,7 +184,7 @@ class IPAdapter(torch.nn.Module):
             self.clip_image_processor = SAFEImageProcessor()
             self.image_encoder = SAFEVisionModel(
                 in_channels=3,
-                num_tokens=8,
+                num_tokens=self.config.safe_tokens,
                 num_vectors=sd.unet.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,

@@ -234,8 +234,8 @@ class IPAdapter(torch.nn.Module):
         dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
         embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]

-        if self.config.image_encoder_arch == 'safe':
-            embedding_dim = self.config.safe_channels
+        # if self.config.image_encoder_arch == 'safe':
+        # embedding_dim = self.config.safe_tokens
         # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
         # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
         # ip-adapter-plus
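The SAFE vision tower's token count was previously hardcoded to 8 and now follows `AdapterConfig.safe_tokens`; since the new kwarg defaults to 8, existing checkpoints keep loading unchanged. The commented-out `embedding_dim` override and the size-mismatch notes appear to record the checkpoint shape errors being worked around. A stand-in sketch of the default-preserving change (this dataclass is illustrative, not the toolkit's AdapterConfig):

from dataclasses import dataclass

@dataclass
class AdapterConfigStub:
    safe_tokens: int = 8  # matches kwargs.get('safe_tokens', 8)

old_behavior = AdapterConfigStub()            # 8 tokens, as before
wider = AdapterConfigStub(safe_tokens=16)     # now configurable
print(old_behavior.safe_tokens, wider.safe_tokens)  # 8 16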