mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 06:49:08 +00:00
Added some experimental training techniques. Ignore for now. Still in testing.
This commit is contained in:
@@ -35,6 +35,7 @@ import math
|
||||
from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -60,6 +61,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||
self.negative_prompt_pool: Union[List[str], None] = None
|
||||
self.batch_negative_prompt: Union[List[str], None] = None
|
||||
self.cfm_cache = None
|
||||
|
||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||
|
||||
@@ -197,7 +199,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
flush()
|
||||
|
||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||
vae = None
|
||||
vae = self.sd.vae
|
||||
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
||||
# vae = self.sd.vae
|
||||
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
||||
@@ -756,13 +758,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pass
|
||||
|
||||
def predict_noise(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
timesteps: Union[int, torch.Tensor] = 1,
|
||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
batch: Optional['DataLoaderBatchDTO'] = None,
|
||||
**kwargs,
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
timesteps: Union[int, torch.Tensor] = 1,
|
||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
batch: Optional['DataLoaderBatchDTO'] = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
return self.sd.predict_noise(
|
||||
@@ -778,6 +780,81 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch=batch,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def cfm_augment_tensors(
|
||||
self,
|
||||
images: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfm_cache is None:
|
||||
# flip the current one. Only need this for first time
|
||||
self.cfm_cache = torch.flip(images, [3]).clone()
|
||||
augmented_tensor_list = []
|
||||
for i in range(images.shape[0]):
|
||||
# get a random one
|
||||
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
|
||||
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
|
||||
augmented = torch.cat(augmented_tensor_list, dim=0)
|
||||
# resize to match the input
|
||||
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
|
||||
self.cfm_cache = images.clone()
|
||||
return augmented
|
||||
|
||||
def get_cfm_loss(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
noise_pred: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
timesteps: torch.Tensor,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
alpha: float = 0.1,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
if hasattr(self.sd, 'get_loss_target'):
|
||||
target = self.sd.get_loss_target(
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
timesteps=timesteps,
|
||||
).detach()
|
||||
|
||||
elif self.sd.is_flow_matching:
|
||||
# forward ODE
|
||||
target = (noise - batch.latents).detach()
|
||||
else:
|
||||
raise ValueError("CFM loss only works with flow matching")
|
||||
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
with torch.no_grad():
|
||||
# we need to compute the contrast
|
||||
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
|
||||
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
|
||||
cfm_noisy_latents = self.sd.add_noise(
|
||||
original_samples=cfm_latents,
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
cfm_pred = self.predict_noise(
|
||||
noisy_latents=cfm_noisy_latents,
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=conditional_embeds,
|
||||
unconditional_embeds=None,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
|
||||
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
|
||||
|
||||
# # Compute cosine similarity at each pixel
|
||||
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
|
||||
|
||||
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
# Compute cosine similarity at each pixel
|
||||
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
|
||||
|
||||
# Average over spatial dimensions, then batch
|
||||
contrastive_loss = -sim.mean()
|
||||
|
||||
loss = fm_loss.mean() + alpha * contrastive_loss
|
||||
return loss
|
||||
|
||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||
self.timer.start('preprocess_batch')
|
||||
@@ -1431,6 +1508,44 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
|
||||
|
||||
if self.train_config.timestep_type == 'next_sample':
|
||||
with self.timer('next_sample_step'):
|
||||
with torch.no_grad():
|
||||
|
||||
stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
|
||||
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
|
||||
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
|
||||
|
||||
# do a sample at the current timestep and step it, then determine new noise
|
||||
next_sample_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
batch=batch,
|
||||
**pred_kwargs
|
||||
)
|
||||
stepped_latents = self.sd.step_scheduler(
|
||||
next_sample_pred,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
self.sd.noise_scheduler
|
||||
)
|
||||
# stepped latents is our new noisy latents. Now we need to determine noise in the current sample
|
||||
noisy_latents = stepped_latents
|
||||
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
# todo calc next timestep, for now this may work as it
|
||||
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
|
||||
if len(stepped_latents.shape) == 4:
|
||||
t_01 = t_01.view(-1, 1, 1, 1)
|
||||
elif len(stepped_latents.shape) == 5:
|
||||
t_01 = t_01.view(-1, 1, 1, 1, 1)
|
||||
else:
|
||||
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
|
||||
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
|
||||
noise = next_sample_noise
|
||||
timesteps = stepped_timesteps
|
||||
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
@@ -1450,15 +1565,25 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
|
||||
prior_to_calculate_loss = None
|
||||
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
if self.train_config.loss_type == 'cfm':
|
||||
loss = self.get_cfm_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
noise=noise,
|
||||
noise_pred=noise_pred,
|
||||
conditional_embeds=conditional_embeds,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
)
|
||||
else:
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_to_calculate_loss,
|
||||
)
|
||||
|
||||
if self.train_config.diff_output_preservation:
|
||||
# send the loss backwards otherwise checkpointing will fail
|
||||
|
||||
Reference in New Issue
Block a user