Added new timestep weighing strategy

Jaret Burkett
2025-06-04 01:16:02 -06:00
parent adc31ec77d
commit 22cdfadab6
8 changed files with 1348 additions and 9 deletions

View File

@@ -437,7 +437,7 @@ class TrainConfig:
# adds an additional loss to the network to encourage it to output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample, weighted
self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
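
For reference, a minimal sketch of selecting the new option through the kwargs-based config shown above (the import path and standalone construction are assumptions for illustration, not code from this commit):

from toolkit.config_modules import TrainConfig  # module path assumed

# Pick the new weighted strategy; every other timestep option keeps its default.
train_config = TrainConfig(timestep_type='weighted')
assert train_config.timestep_type == 'weighted'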

View File

@@ -1773,6 +1773,97 @@ class LatentCachingMixin:
self.sd.restore_device_state()
class TextEmbeddingCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
def cache_text_embeddings(self: 'AiToolkitDataset'):
with accelerator.main_process_first():
print_acc(f"Caching text_embeddings for {self.dataset_path}")
# cache all latents to disk
to_disk = self.is_caching_latents_to_disk
to_memory = self.is_caching_latents_to_memory
print_acc(" - Saving text embeddings to disk")
# move sd items to cpu except for vae
self.sd.set_device_state_preset('cache_latents')
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
# set latent space version
if self.sd.model_config.latent_space_version is not None:
file_item.latent_space_version = self.sd.model_config.latent_space_version
elif self.sd.is_xl:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_v3:
file_item.latent_space_version = 'sd3'
elif self.sd.is_auraflow:
file_item.latent_space_version = 'sdxl'
elif self.sd.is_flux:
file_item.latent_space_version = 'flux1'
elif self.sd.model_config.is_pixart_sigma:
file_item.latent_space_version = 'sdxl'
else:
file_item.latent_space_version = self.sd.model_config.arch
file_item.is_caching_to_disk = to_disk
file_item.is_caching_to_memory = to_memory
file_item.latent_load_device = self.sd.device
latent_path = file_item.get_latent_path(recalculate=True)
# check if it is saved to disk already
if os.path.exists(latent_path):
if to_memory:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
save_file(state_dict, latent_path, metadata=meta)
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
class CLIPCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it
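
The caching path above round-trips tensors through safetensors files. A minimal standalone sketch of that save/load pattern (the tensor, path, and metadata are placeholders, not values from this dataset code):

import torch
from collections import OrderedDict
from safetensors.torch import save_file, load_file

# Placeholder tensor and path standing in for file_item's encoded latent.
latent = torch.randn(16, 64, 64)
latent_path = '/tmp/example_latent.safetensors'

# safetensors metadata must map strings to strings.
save_file(OrderedDict([('latent', latent.cpu())]), latent_path, metadata={'version': '1'})

# Load back onto CPU, matching the to_memory branch above.
state_dict = load_file(latent_path, device='cpu')
restored = state_dict['latent']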

View File

@@ -4,6 +4,7 @@ from torch.distributions import LogNormal
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
import numpy as np
from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
def calculate_shift(
@@ -47,20 +48,26 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
hbsmntw_weighing[num_timesteps //
2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
# Create linear timesteps from 1000 to 1
timesteps = torch.linspace(1000, 1, num_timesteps, device='cpu')
self.linear_timesteps = timesteps
self.linear_timesteps_weights = bsmntw_weighing
self.linear_timesteps_weights2 = hbsmntw_weighing
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item()
for t in timesteps]
# Get the weights for the timesteps
if timestep_type == "weighted":
weights = torch.tensor(
[default_weighing_scheme[i] for i in step_indices],
device=timesteps.device,
dtype=timesteps.dtype
)
if v2:
weights = self.linear_timesteps_weights2[step_indices].flatten()
else:
@@ -106,8 +113,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
patch_size=1
):
self.timestep_type = timestep_type
if timestep_type == 'linear':
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
if timestep_type == 'linear' or timestep_type == 'weighted':
timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
elif timestep_type == 'sigmoid':
@@ -198,7 +205,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
t1 = ((1 - t1/t1.max()) * 1000)
# add half of linear
t2 = torch.linspace(1000, 0, int(
t2 = torch.linspace(1000, 1, int(
num_timesteps * (1 - alpha)), device=device)
timesteps = torch.cat((t1, t2))
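
For context, a rough sketch of how per-timestep weights like these are typically folded into a per-sample loss (values and shapes are invented for illustration; this is not the trainer's actual code):

import torch

# Stand-in for a per-sample flow-matching loss, e.g. ((pred - target) ** 2).mean(dim=[1, 2, 3]).
per_sample_loss = torch.rand(4)
timesteps = torch.tensor([1000.0, 667.0, 334.0, 1.0])  # sampled training timesteps

# weights = scheduler.get_weights_for_timesteps(timesteps, timestep_type='weighted')
weights = torch.tensor([0.5, 1.2, 1.0, 0.3])  # placeholder weights, one per sample

# Each sample's loss is scaled by the weight of its sampled timestep.
loss = (per_sample_loss * weights).mean()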

View File

File diff suppressed because it is too large
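
The suppressed diff is presumably the imported default_weighing_scheme table. Given how it is indexed above (one entry per discrete timestep index), its shape is plausibly along these lines, assuming 1000 training timesteps; the numbers here are invented placeholders, not the committed values:

# One float per timestep index; the real table's values are not reproduced here.
default_weighing_scheme = [0.35, 0.41, 0.48] + [1.0] * 994 + [0.52, 0.44, 0.37]
assert len(default_weighing_scheme) == 1000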

Binary file not shown (added image, 190 KiB)