mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added a guidance burning loss. Modified DFE to work with new model. Bug fixes
This commit is contained in:
@@ -61,7 +61,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||
self.negative_prompt_pool: Union[List[str], None] = None
|
||||
self.batch_negative_prompt: Union[List[str], None] = None
|
||||
self.cfm_cache = None
|
||||
|
||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||
|
||||
@@ -84,6 +83,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
|
||||
|
||||
self.dfe: Optional[DiffusionFeatureExtractor] = None
|
||||
self.unconditional_embeds = None
|
||||
|
||||
if self.train_config.diff_output_preservation:
|
||||
if self.trigger_word is None:
|
||||
@@ -95,6 +95,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# always do a prior prediction when doing diff output preservation
|
||||
self.do_prior_prediction = True
|
||||
|
||||
# store the loss target for a batch so we can use it in a loss
|
||||
self._guidance_loss_target_batch: float = 0.0
|
||||
if isinstance(self.train_config.guidance_loss_target, (int, float)):
|
||||
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
|
||||
elif isinstance(self.train_config.guidance_loss_target, list):
|
||||
self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
|
||||
else:
|
||||
raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")
|
||||
|
||||
|
||||
def before_model_load(self):
|
||||
@@ -135,6 +144,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
|
||||
# cache unconditional embeds (blank prompt)
|
||||
with torch.no_grad():
|
||||
self.unconditional_embeds = self.sd.encode_prompt(
|
||||
[''],
|
||||
long_prompts=self.do_long_prompts
|
||||
).to(
|
||||
self.device_torch,
|
||||
dtype=self.sd.torch_dtype
|
||||
).detach()
|
||||
|
||||
if self.train_config.do_prior_divergence:
|
||||
self.do_prior_prediction = True
|
||||
# move vae to device if we did not cache latents
|
||||
@@ -476,6 +495,47 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
|
||||
else:
|
||||
raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
|
||||
|
||||
if self.train_config.do_guidance_loss:
|
||||
with torch.no_grad():
|
||||
# we make cached blank prompt embeds that match the batch size
|
||||
unconditional_embeds = concat_prompt_embeds(
|
||||
[self.unconditional_embeds] * noisy_latents.shape[0],
|
||||
)
|
||||
cfm_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=unconditional_embeds,
|
||||
unconditional_embeds=None,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# zero cfg
|
||||
|
||||
# ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
|
||||
batch_size = target.shape[0]
|
||||
positive_flat = target.view(batch_size, -1)
|
||||
negative_flat = cfm_pred.view(batch_size, -1)
|
||||
# Calculate dot production
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
# Squared norm of uncondition
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
st_star = dot_product / squared_norm
|
||||
|
||||
alpha = st_star
|
||||
|
||||
is_video = len(target.shape) == 5
|
||||
|
||||
alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)
|
||||
|
||||
guidance_scale = self._guidance_loss_target_batch
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
|
||||
guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)
|
||||
|
||||
unconditional_target = cfm_pred * alpha
|
||||
target = unconditional_target + guidance_scale * (target - unconditional_target)
|
||||
|
||||
|
||||
if target is None:
|
||||
@@ -895,6 +955,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
guidance_embedding_scale = self.train_config.cfg_scale
|
||||
if self.train_config.do_guidance_loss:
|
||||
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
@@ -902,6 +966,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=guidance_embedding_scale,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
batch=batch,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
@@ -945,13 +1010,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
guidance_embedding_scale = self.train_config.cfg_scale
|
||||
if self.train_config.do_guidance_loss:
|
||||
guidance_embedding_scale = self._guidance_loss_target_batch
|
||||
return self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=self.train_config.cfg_scale,
|
||||
guidance_embedding_scale=guidance_embedding_scale,
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
|
||||
@@ -959,80 +1027,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def cfm_augment_tensors(
|
||||
self,
|
||||
images: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfm_cache is None:
|
||||
# flip the current one. Only need this for first time
|
||||
self.cfm_cache = torch.flip(images, [3]).clone()
|
||||
augmented_tensor_list = []
|
||||
for i in range(images.shape[0]):
|
||||
# get a random one
|
||||
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
|
||||
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
|
||||
augmented = torch.cat(augmented_tensor_list, dim=0)
|
||||
# resize to match the input
|
||||
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
|
||||
self.cfm_cache = images.clone()
|
||||
return augmented
|
||||
|
||||
def get_cfm_loss(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
noise_pred: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
timesteps: torch.Tensor,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
alpha: float = 0.1,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
if hasattr(self.sd, 'get_loss_target'):
|
||||
target = self.sd.get_loss_target(
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
timesteps=timesteps,
|
||||
).detach()
|
||||
|
||||
elif self.sd.is_flow_matching:
|
||||
# forward ODE
|
||||
target = (noise - batch.latents).detach()
|
||||
else:
|
||||
raise ValueError("CFM loss only works with flow matching")
|
||||
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
with torch.no_grad():
|
||||
# we need to compute the contrast
|
||||
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
|
||||
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
|
||||
cfm_noisy_latents = self.sd.add_noise(
|
||||
original_samples=cfm_latents,
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
cfm_pred = self.predict_noise(
|
||||
noisy_latents=cfm_noisy_latents,
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=conditional_embeds,
|
||||
unconditional_embeds=None,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
|
||||
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
|
||||
|
||||
# # Compute cosine similarity at each pixel
|
||||
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
|
||||
|
||||
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
# Compute cosine similarity at each pixel
|
||||
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
|
||||
|
||||
# Average over spatial dimensions, then batch
|
||||
contrastive_loss = -sim.mean()
|
||||
|
||||
loss = fm_loss.mean() + alpha * contrastive_loss
|
||||
return loss
|
||||
|
||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||
with torch.no_grad():
|
||||
@@ -1658,6 +1652,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
)
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
|
||||
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
|
||||
|
||||
if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list):
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# update the guidance value, random float between guidance_loss_target[0] and guidance_loss_target[1]
|
||||
self._guidance_loss_target_batch = [
|
||||
random.uniform(
|
||||
self.train_config.guidance_loss_target[0],
|
||||
self.train_config.guidance_loss_target[1]
|
||||
) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
self.before_unet_predict()
|
||||
|
||||
@@ -1757,25 +1761,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||
prior_to_calculate_loss = None
|
||||
|
||||
if self.train_config.loss_type == 'cfm':
|
||||
loss = self.get_cfm_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
conditional_embeds=conditional_embeds,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
)
|
||||
else:
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
|
||||
if self.train_config.diff_output_preservation:
|
||||
# send the loss backwards otherwise checkpointing will fail
|
||||
|
||||
Reference in New Issue
Block a user