Added some experimental training techniques. Ignore for now. Still in testing.

This commit is contained in:
Jaret Burkett
2025-05-21 02:19:54 -06:00
parent 01101be196
commit e5181d23cd
6 changed files with 240 additions and 43 deletions

16
.vscode/launch.json vendored
View File

@@ -16,6 +16,22 @@
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Run current config (cuda:1)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/run.py",
"args": [
"${file}"
],
"env": {
"CUDA_LAUNCH_BLOCKING": "1",
"DEBUG_TOOLKIT": "1",
"CUDA_VISIBLE_DEVICES": "1"
},
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Python: Debug Current File",
"type": "python",

21
build_and_push_docker_dev Normal file
View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
VERSION=dev
GIT_COMMIT=dev
echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
echo "Building version: $VERSION and latest"
# wait 2 seconds
sleep 2
# Build the image with cache busting
docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
# Tag with version and latest
docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
# Push both tags
echo "Pushing images to Docker Hub..."
docker push ostris/aitoolkit:$VERSION
echo "Successfully built and pushed ostris/aitoolkit:$VERSION"

View File

@@ -35,6 +35,7 @@ import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
def flush():
@@ -60,6 +61,7 @@ class SDTrainer(BaseSDTrainProcess):
self._clip_image_embeds_unconditional: Union[List[str], None] = None
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.cfm_cache = None
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
@@ -197,7 +199,7 @@ class SDTrainer(BaseSDTrainProcess):
flush()
if self.train_config.diffusion_feature_extractor_path is not None:
vae = None
vae = self.sd.vae
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
# vae = self.sd.vae
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
@@ -756,13 +758,13 @@ class SDTrainer(BaseSDTrainProcess):
pass
def predict_noise(
self,
noisy_latents: torch.Tensor,
timesteps: Union[int, torch.Tensor] = 1,
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
batch: Optional['DataLoaderBatchDTO'] = None,
**kwargs,
self,
noisy_latents: torch.Tensor,
timesteps: Union[int, torch.Tensor] = 1,
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
batch: Optional['DataLoaderBatchDTO'] = None,
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
return self.sd.predict_noise(
@@ -778,6 +780,81 @@ class SDTrainer(BaseSDTrainProcess):
batch=batch,
**kwargs
)
def cfm_augment_tensors(
self,
images: torch.Tensor
) -> torch.Tensor:
if self.cfm_cache is None:
# flip the current one. Only need this for first time
self.cfm_cache = torch.flip(images, [3]).clone()
augmented_tensor_list = []
for i in range(images.shape[0]):
# get a random one
idx = random.randint(0, self.cfm_cache.shape[0] - 1)
augmented_tensor_list.append(self.cfm_cache[idx:idx + 1])
augmented = torch.cat(augmented_tensor_list, dim=0)
# resize to match the input
augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear')
self.cfm_cache = images.clone()
return augmented
def get_cfm_loss(
self,
noisy_latents: torch.Tensor,
noise: torch.Tensor,
noise_pred: torch.Tensor,
conditional_embeds: PromptEmbeds,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
alpha: float = 0.1,
):
dtype = get_torch_dtype(self.train_config.dtype)
if hasattr(self.sd, 'get_loss_target'):
target = self.sd.get_loss_target(
noise=noise,
batch=batch,
timesteps=timesteps,
).detach()
elif self.sd.is_flow_matching:
# forward ODE
target = (noise - batch.latents).detach()
else:
raise ValueError("CFM loss only works with flow matching")
fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
with torch.no_grad():
# we need to compute the contrast
cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype)
cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype)
cfm_noisy_latents = self.sd.add_noise(
original_samples=cfm_latents,
noise=noise,
timesteps=timesteps,
)
cfm_pred = self.predict_noise(
noisy_latents=cfm_noisy_latents,
timesteps=timesteps,
conditional_embeds=conditional_embeds,
unconditional_embeds=None,
batch=batch,
)
# v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1)
# v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W)
# # Compute cosine similarity at each pixel
# sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W)
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# Compute cosine similarity at each pixel
sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W)
# Average over spatial dimensions, then batch
contrastive_loss = -sim.mean()
loss = fm_loss.mean() + alpha * contrastive_loss
return loss
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
@@ -1431,6 +1508,44 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
if self.train_config.timestep_type == 'next_sample':
with self.timer('next_sample_step'):
with torch.no_grad():
stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps]
stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies]
stepped_timesteps = torch.stack(stepped_timesteps, dim=0)
# do a sample at the current timestep and step it, then determine new noise
next_sample_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
stepped_latents = self.sd.step_scheduler(
next_sample_pred,
noisy_latents,
timesteps,
self.sd.noise_scheduler
)
# stepped latents is our new noisy latents. Now we need to determine noise in the current sample
noisy_latents = stepped_latents
original_samples = batch.latents.to(self.device_torch, dtype=dtype)
# todo calc next timestep, for now this may work as it
t_01 = (stepped_timesteps / 1000).to(original_samples.device)
if len(stepped_latents.shape) == 4:
t_01 = t_01.view(-1, 1, 1, 1)
elif len(stepped_latents.shape) == 5:
t_01 = t_01.view(-1, 1, 1, 1, 1)
else:
raise ValueError("Unknown stepped latents shape", stepped_latents.shape)
next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01
noise = next_sample_noise
timesteps = stepped_timesteps
with self.timer('predict_unet'):
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
@@ -1450,15 +1565,25 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
if self.train_config.loss_type == 'cfm':
loss = self.get_cfm_loss(
noisy_latents=noisy_latents,
noise=noise,
noise_pred=noise_pred,
conditional_embeds=conditional_embeds,
timesteps=timesteps,
batch=batch,
)
else:
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
# send the loss backwards otherwise checkpointing will fail

View File

@@ -931,16 +931,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
dtype=noise.dtype) * 2 - 1
# if self.train_config.random_noise_shift > 0.0:
# # get random noise -1 to 1
# noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
# dtype=noise.dtype) * 2 - 1
# multiply by shift amount
noise_shift *= self.train_config.random_noise_shift
# # multiply by shift amount
# noise_shift *= self.train_config.random_noise_shift
# add to noise
noise += noise_shift
# # add to noise
# noise += noise_shift
if self.train_config.blended_blur_noise:
noise = get_blended_blur_noise(
@@ -1011,6 +1011,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1118,8 +1119,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
if timestep_type is None:
timestep_type = self.train_config.timestep_type
if self.train_config.timestep_type == 'next_sample':
# simulate a sample
num_train_timesteps = self.train_config.next_sample_timesteps
timestep_type = 'shift'
patch_size = 1
if self.sd.is_flux:
if self.sd.is_flux or 'flex' in self.sd.arch:
# flux is a patch size of 1, but latents are divided by 2, so we need to double it
patch_size = 2
elif hasattr(self.sd.unet.config, 'patch_size'):
@@ -1142,7 +1148,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
content_or_style = self.train_config.content_or_style_reg
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if content_or_style in ['style', 'content']:
if self.train_config.timestep_type == 'next_sample':
timestep_indices = torch.randint(
0,
num_train_timesteps - 2, # -1 for 0 idx, -1 so we can step
(batch_size,),
device=self.device_torch
)
timestep_indices = timestep_indices.long()
elif content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
@@ -1169,7 +1183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
min_noise_steps + 1,
max_noise_steps - 1
)
elif content_or_style == 'balanced':
if min_noise_steps == max_noise_steps:
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
@@ -1185,16 +1199,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")
# do flow matching
# if self.sd.is_flow_matching:
# u = compute_density_for_timestep_sampling(
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
# batch_size=batch_size,
# logit_mean=0.0,
# logit_std=1.0,
# mode_scale=1.29,
# )
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
@@ -1218,8 +1222,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
latents = unaugmented_latents
noise_multiplier = self.train_config.noise_multiplier
s = (noise.shape[0], noise.shape[1], 1, 1)
if len(noise.shape) == 5:
# if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
if self.train_config.random_noise_multiplier > 0.0:
# do it on a per batch item, per channel basis
noise_multiplier = 1 + torch.randn(
s,
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_multiplier
noise = noise * noise_multiplier
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.randn(
s,
device=noise.device,
dtype=noise.dtype
) * self.train_config.random_noise_shift
# add to noise
noise += noise_shift
latent_multiplier = self.train_config.latent_multiplier

View File

@@ -325,6 +325,8 @@ class TrainConfig:
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
@@ -333,7 +335,6 @@ class TrainConfig:
# multiplier applied to loos on regularization images
self.reg_weight = kwargs.get('reg_weight', 1.0)
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
# automatically adapte the vae scaling based on the image norm
self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
@@ -412,7 +413,7 @@ class TrainConfig:
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm
# scale the prediction by this. Increase for more detail, decrease for less
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
@@ -436,7 +437,8 @@ class TrainConfig:
# adds an additional loss to the network to encourage it output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
self.disable_sampling = kwargs.get('disable_sampling', False)

View File

@@ -142,7 +142,9 @@ class StableDiffusion:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = device
self.device = str(device)
if "cuda" in self.device and ":" not in self.device:
self.device = f"{self.device}:0"
self.device_torch = torch.device(device)
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
@@ -2086,7 +2088,10 @@ class StableDiffusion:
noise_pred = noise_pred
else:
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception as e:
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
if self.is_flux: