mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added the ability to do classifier-free guidance (CFG) during training. Various bug fixes.
This commit is contained in:
@@ -319,6 +319,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
**kwargs
|
||||
):
|
||||
loss = get_guidance_loss(
|
||||
@@ -331,6 +332,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch=batch,
|
||||
noise=noise,
|
||||
sd=self.sd,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -618,6 +620,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
unconditional_embeds: Optional[PromptEmbeds] = None,
|
||||
**kwargs
|
||||
):
|
||||
# todo for embeddings, we need to run without trigger words
|
||||
@@ -655,9 +658,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# self.network.multiplier = 0.0
|
||||
self.sd.unet.eval()
|
||||
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
@@ -901,6 +908,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
with self.timer('encode_prompt'):
|
||||
unconditional_embeds = None
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
@@ -909,6 +917,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
if self.train_config.do_cfg:
|
||||
# todo only do one and repeat it
|
||||
unconditional_embeds = self.sd.encode_prompt(
|
||||
["" for _ in range(noisy_latents.shape[0])],
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
with torch.set_grad_enabled(False):
|
||||
# make sure it is in eval mode
|
||||
@@ -923,9 +940,19 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
if self.train_config.do_cfg:
|
||||
# todo only do one and repeat it
|
||||
unconditional_embeds = self.sd.encode_prompt(
|
||||
["" for _ in range(noisy_latents.shape[0])],
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = unconditional_embeds.detach()
|
||||
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
@@ -965,21 +992,43 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
drop=True,
|
||||
is_training=True
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
torch.zeros(
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
device=self.device_torch, dtype=dtype
|
||||
).detach(),
|
||||
is_training=True,
|
||||
drop=True
|
||||
)
|
||||
elif has_clip_image:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
torch.zeros(
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
device=self.device_torch, dtype=dtype
|
||||
).detach(),
|
||||
is_training=True,
|
||||
drop=True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
|
||||
|
||||
if not self.adapter_config.train_image_encoder:
|
||||
# we are not training the image encoder, so we need to detach the embeds
|
||||
conditional_clip_embeds = conditional_clip_embeds.detach()
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = unconditional_clip_embeds.detach()
|
||||
|
||||
|
||||
with self.timer('encode_adapter'):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(), unconditional_clip_embeds)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass in our scheduler
|
||||
@@ -1017,6 +1066,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs=pred_kwargs,
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
unconditional_embeds=unconditional_embeds
|
||||
)
|
||||
|
||||
self.before_unet_predict()
|
||||
@@ -1032,13 +1082,17 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs=pred_kwargs,
|
||||
batch=batch,
|
||||
noise=noise,
|
||||
unconditional_embeds=unconditional_embeds
|
||||
)
|
||||
|
||||
else:
|
||||
with self.timer('predict_unet'):
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype)
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
|
||||
Reference in New Issue
Block a user