Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added new timestep weighing strategy
@@ -501,13 +501,22 @@ class SDTrainer(BaseSDTrainProcess):
             loss = wavelet_loss(pred, batch.latents, noise)
         else:
             loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
 
+        do_weighted_timesteps = False
+        if self.sd.is_flow_matching:
+            if self.train_config.linear_timesteps or self.train_config.linear_timesteps2:
+                do_weighted_timesteps = True
+            if self.train_config.timestep_type == "weighted":
+                # use the noise scheduler to get the weights for the timesteps
+                do_weighted_timesteps = True
+
         # handle linear timesteps and only adjust the weight of the timesteps
-        if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
+        if do_weighted_timesteps:
             # calculate the weights for the timesteps
             timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
                 timesteps,
-                v2=self.train_config.linear_timesteps2
+                v2=self.train_config.linear_timesteps2,
+                timestep_type=self.train_config.timestep_type
             ).to(loss.device, dtype=loss.dtype)
             timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
             loss = loss * timestep_weight
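For readers following the hunk above: the loss is computed with `reduction="none"`, so each batch element keeps its per-pixel values and can be rescaled by the weight of the timestep it was noised at, before the final reduction. A minimal self-contained sketch of that pattern (the tensor shapes and the weight table are illustrative stand-ins, not the trainer's actual objects):

```python
import torch

# toy stand-ins for the trainer's tensors (hypothetical shapes)
batch = 4
pred = torch.randn(batch, 4, 32, 32)
target = torch.randn(batch, 4, 32, 32)
weights_per_timestep = torch.rand(1000)       # e.g. loaded from flex_timestep_weights.json
timesteps = torch.randint(0, 1000, (batch,))  # one sampled timestep per batch item

# unreduced loss keeps the per-element values so they can be reweighted
loss = torch.nn.functional.mse_loss(pred, target, reduction="none")

# look up each sample's weight and broadcast it over (C, H, W)
timestep_weight = weights_per_timestep[timesteps].view(-1, 1, 1, 1).detach()
loss = (loss * timestep_weight).mean()
```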
228  scripts/calculate_timestep_weighing_flex.py  Normal file
@@ -0,0 +1,228 @@
import gc
import os, sys
from tqdm import tqdm
import numpy as np
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# set visible devices to 0
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# protect imports from formatting
if True:
    import torch
    from optimum.quanto import freeze, qfloat8, QTensor, qint4
    from diffusers import FluxTransformer2DModel, FluxPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler
    from toolkit.util.quantize import quantize, get_qtype
    from transformers import T5EncoderModel, T5TokenizerFast, CLIPTextModel, CLIPTokenizer
    from torchvision import transforms

qtype = "qfloat8"
dtype = torch.bfloat16
# base_model_path = "black-forest-labs/FLUX.1-dev"
base_model_path = "ostris/Flex.1-alpha"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Loading Transformer...")
prompt = "Photo of a man and a woman in a park, sunny day"

output_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output")
output_path = os.path.join(output_root, "flex_timestep_weights.json")
img_output_path = os.path.join(output_root, "flex_timestep_weights.png")

quantization_type = get_qtype(qtype)


def flush():
    torch.cuda.empty_cache()
    gc.collect()


pil_to_tensor = transforms.ToTensor()

with torch.no_grad():
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path,
        subfolder='transformer',
        torch_dtype=dtype
    )

    transformer.to(device, dtype=dtype)

    print("Quantizing Transformer...")
    quantize(transformer, weights=quantization_type)
    freeze(transformer)
    flush()

    print("Loading Scheduler...")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")

    print("Loading Autoencoder...")
    vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
    vae.to(device, dtype=dtype)

    flush()
    print("Loading Text Encoder...")
    tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
    text_encoder_2.to(device, dtype=dtype)

    print("Quantizing Text Encoder...")
    quantize(text_encoder_2, weights=get_qtype(qtype))
    freeze(text_encoder_2)
    flush()

    print("Loading CLIP...")
    text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
    text_encoder.to(device, dtype=dtype)

    print("Making pipe...")
    pipe: FluxPipeline = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer

    pipe.to(device, dtype=dtype)

    print("Encoding prompt...")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt,
        prompt_2=prompt,
        device=device
    )

    generator = torch.manual_seed(42)

    height = 1024
    width = 1024

    print("Generating image...")

    # work around a dtype bug in diffusers/torch
    def callback_on_step_end(pipe, i, t, callback_kwargs):
        latents = callback_kwargs["latents"]
        if latents.dtype != dtype:
            latents = latents.to(dtype)
        return {"latents": latents}

    img = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        height=height,
        width=width,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=generator,
        callback_on_step_end=callback_on_step_end,
    ).images[0]

    img.save(img_output_path)
    print(f"Image saved to {img_output_path}")

    print("Encoding image...")
    # img is a PIL image. convert it to a -1 to 1 tensor
    img = pil_to_tensor(img)
    img = img.unsqueeze(0)  # add batch dimension
    img = img * 2 - 1  # convert to -1 to 1 range
    img = img.to(device, dtype=dtype)
    latents = vae.encode(img).latent_dist.sample()

    shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
    latents = vae.config['scaling_factor'] * (latents - shift)

    num_channels_latents = pipe.transformer.config.in_channels // 4

    l_height = 2 * (int(height) // (pipe.vae_scale_factor * 2))
    l_width = 2 * (int(width) // (pipe.vae_scale_factor * 2))
    packed_latents = pipe._pack_latents(latents, 1, num_channels_latents, l_height, l_width)

    packed_latents, latent_image_ids = pipe.prepare_latents(
        1,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        packed_latents,
    )

    print("Calculating timestep weights...")

    torch.manual_seed(8675309)
    noise = torch.randn_like(packed_latents, device=device, dtype=dtype)

    # Create linear timesteps from 1000 to 1
    num_train_timesteps = 1000
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

    timestep_weights = torch.zeros(num_train_timesteps, dtype=torch.float32, device=device)

    guidance = torch.full([1], 1.0, device=device, dtype=torch.float32)
    guidance = guidance.expand(latents.shape[0])

    pbar = tqdm(range(num_train_timesteps), desc="loss: 0.000000 scaler: 0.0000")
    for i in pbar:
        timestep = timesteps[i:i + 1].to(device)
        t_01 = (timestep / 1000).to(device)
        t_01 = t_01.reshape(-1, 1, 1)
        # flow-matching interpolation between data and noise at this timestep
        noisy_latents = (1.0 - t_01) * packed_latents + t_01 * noise

        noise_pred = pipe.transformer(
            hidden_states=noisy_latents,  # torch.Size([1, 4096, 64])
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            return_dict=False,
        )[0]

        target = noise - packed_latents

        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())

        # determine scaler to multiply loss by to make it 1
        scaler = 1.0 / (loss + 1e-6)

        timestep_weights[i] = scaler
        pbar.set_description(f"loss: {loss.item():.6f} scaler: {scaler.item():.4f}")

    print("Normalizing timestep weights...")
    # normalize the timestep weights so they have a mean of 1.0
    timestep_weights = timestep_weights / timestep_weights.mean()
    timestep_weights = timestep_weights.cpu().numpy().tolist()

    print("Saving timestep weights...")
    with open(output_path, 'w') as f:
        json.dump(timestep_weights, f)

    print(f"Timestep weights saved to {output_path}")
    print("Done!")
    flush()
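What the script measures: for each of the 1000 flow-matching timesteps it runs a single forward pass against the target `noise - packed_latents`, takes the raw MSE, and stores its reciprocal as that timestep's weight, so a timestep that naturally produces a large loss is scaled down and vice versa; the weights are then normalized to a mean of 1.0 so the overall loss magnitude is left unchanged. A hedged sketch of just that reciprocal-and-normalize step, with random stand-ins for the measured losses:

```python
import torch

# per-timestep raw losses as the script would measure them (random stand-ins here)
raw_losses = torch.rand(1000) * 2 + 0.1

# reciprocal scaler per timestep: weight_i = 1 / (loss_i + eps)
weights = 1.0 / (raw_losses + 1e-6)

# normalize so the weights have mean 1.0, leaving the expected loss scale unchanged
weights = weights / weights.mean()
assert abs(weights.mean().item() - 1.0) < 1e-5
```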
@@ -437,7 +437,7 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it to output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
-        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend, next_sample
+        self.timestep_type = kwargs.get('timestep_type', 'sigmoid')  # sigmoid, linear, lognorm_blend, next_sample, weighted
         self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8)
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
         self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
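Since `TrainConfig` reads these options from kwargs, selecting the new scheme is just a matter of passing `timestep_type='weighted'`. A hypothetical minimal instantiation (the import path is assumed, and all other required config keys are omitted):

```python
from toolkit.config_modules import TrainConfig  # assumed module path

train_config = TrainConfig(
    timestep_type="weighted",  # one of: sigmoid, linear, lognorm_blend, next_sample, weighted
    linear_timesteps=False,
    linear_timesteps2=False,
)
```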
@@ -1773,6 +1773,97 @@ class LatentCachingMixin:
         self.sd.restore_device_state()
 
 
+class TextEmbeddingCachingMixin:
+    def __init__(self: 'AiToolkitDataset', **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(**kwargs)
+        self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
+
+    def cache_text_embeddings(self: 'AiToolkitDataset'):
+        with accelerator.main_process_first():
+            print_acc(f"Caching text_embeddings for {self.dataset_path}")
+            # cache all latents to disk
+            to_disk = self.is_caching_latents_to_disk
+            to_memory = self.is_caching_latents_to_memory
+            print_acc(" - Saving text embeddings to disk")
+            # move sd items to cpu except for vae
+            self.sd.set_device_state_preset('cache_latents')
+
+            # use tqdm to show progress
+            i = 0
+            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
+                # set latent space version
+                if self.sd.model_config.latent_space_version is not None:
+                    file_item.latent_space_version = self.sd.model_config.latent_space_version
+                elif self.sd.is_xl:
+                    file_item.latent_space_version = 'sdxl'
+                elif self.sd.is_v3:
+                    file_item.latent_space_version = 'sd3'
+                elif self.sd.is_auraflow:
+                    file_item.latent_space_version = 'sdxl'
+                elif self.sd.is_flux:
+                    file_item.latent_space_version = 'flux1'
+                elif self.sd.model_config.is_pixart_sigma:
+                    file_item.latent_space_version = 'sdxl'
+                else:
+                    file_item.latent_space_version = self.sd.model_config.arch
+                file_item.is_caching_to_disk = to_disk
+                file_item.is_caching_to_memory = to_memory
+                file_item.latent_load_device = self.sd.device
+
+                latent_path = file_item.get_latent_path(recalculate=True)
+                # check if it is saved to disk already
+                if os.path.exists(latent_path):
+                    if to_memory:
+                        # load it into memory
+                        state_dict = load_file(latent_path, device='cpu')
+                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+                else:
+                    # not saved to disk, calculate
+                    # load the image first
+                    file_item.load_and_process_image(self.transform, only_load_latents=True)
+                    dtype = self.sd.torch_dtype
+                    device = self.sd.device_torch
+                    # add batch dimension
+                    try:
+                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+                        latent = self.sd.encode_images(imgs).squeeze(0)
+                    except Exception as e:
+                        print_acc(f"Error processing image: {file_item.path}")
+                        print_acc(f"Error: {str(e)}")
+                        raise e
+                    # save latent
+                    if to_disk:
+                        state_dict = OrderedDict([
+                            ('latent', latent.clone().detach().cpu()),
+                        ])
+                        # metadata
+                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
+                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
+                        save_file(state_dict, latent_path, metadata=meta)
+
+                    if to_memory:
+                        # keep it in memory
+                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+
+                    del imgs
+                    del latent
+                    del file_item.tensor
+
+                # flush(garbage_collect=False)
+                file_item.is_latent_cached = True
+                i += 1
+                # flush every 100
+                # if i % 100 == 0:
+                #     flush()
+
+        # restore device state
+        self.sd.restore_device_state()
+
+
 class CLIPCachingMixin:
     def __init__(self: 'AiToolkitDataset', **kwargs):
         # if we have super, call it
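The new mixin follows the same cooperative-`super()` pattern as the repo's other caching mixins, so several of them can be stacked onto the dataset class and each `__init__` runs exactly once along the MRO. A toy illustration of that pattern (class names are stand-ins, not the repo's):

```python
class LatentCaching:
    def __init__(self, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.latents_cached = True


class TextEmbeddingCaching:
    def __init__(self, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.text_embeddings_cached = True


class Dataset(LatentCaching, TextEmbeddingCaching):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


d = Dataset()
print(d.latents_cached, d.text_embeddings_cached)  # True True
```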
@@ -4,6 +4,7 @@ from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
 import numpy as np
+from toolkit.timestep_weighing.default_weighing_scheme import default_weighing_scheme
 
 
 def calculate_shift(

@@ -47,20 +48,26 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         hbsmntw_weighing[num_timesteps //
                          2:] = hbsmntw_weighing[num_timesteps // 2:].max()
 
-        # Create linear timesteps from 1000 to 0
-        timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
+        # Create linear timesteps from 1000 to 1
+        timesteps = torch.linspace(1000, 1, num_timesteps, device='cpu')
 
         self.linear_timesteps = timesteps
         self.linear_timesteps_weights = bsmntw_weighing
         self.linear_timesteps_weights2 = hbsmntw_weighing
         pass
 
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
+    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor:
         # Get the indices of the timesteps
         step_indices = [(self.timesteps == t).nonzero().item()
                         for t in timesteps]
 
         # Get the weights for the timesteps
+        if timestep_type == "weighted":
+            weights = torch.tensor(
+                [default_weighing_scheme[i] for i in step_indices],
+                device=timesteps.device,
+                dtype=timesteps.dtype
+            )
-        if v2:
+        elif v2:
             weights = self.linear_timesteps_weights2[step_indices].flatten()
         else:

@@ -106,8 +113,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         patch_size=1
     ):
         self.timestep_type = timestep_type
-        if timestep_type == 'linear':
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
+        if timestep_type == 'linear' or timestep_type == 'weighted':
+            timesteps = torch.linspace(1000, 1, num_timesteps, device=device)
             self.timesteps = timesteps
             return timesteps
         elif timestep_type == 'sigmoid':

@@ -198,7 +205,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         t1 = ((1 - t1/t1.max()) * 1000)
 
         # add half of linear
-        t2 = torch.linspace(1000, 0, int(
+        t2 = torch.linspace(1000, 1, int(
             num_timesteps * (1 - alpha)), device=device)
         timesteps = torch.cat((t1, t2))
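The lookup in `get_weights_for_timesteps` maps each sampled timestep back to its index in the scheduler's 1000-to-1 grid and pulls the precomputed weight for that index. A standalone sketch of that indexing, with a made-up weighing table in place of `default_weighing_scheme`:

```python
import torch

# scheduler's timestep grid, 1000 down to 1, and a stand-in weight table
all_timesteps = torch.linspace(1000, 1, 1000)
weighing_scheme = [1.0 + 0.001 * i for i in range(1000)]  # placeholder values

# timesteps sampled for a batch (must be members of the grid)
batch_timesteps = all_timesteps[torch.tensor([0, 499, 999])]

# map each timestep to its index in the grid, then look up its weight
step_indices = [(all_timesteps == t).nonzero().item() for t in batch_timesteps]
weights = torch.tensor([weighing_scheme[i] for i in step_indices])
print(weights)  # tensor([1.0000, 1.4990, 1.9990])
```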
0     toolkit/timestep_weighing/__init__.py  Normal file
1004  toolkit/timestep_weighing/default_weighing_scheme.py  Normal file (diff suppressed because it is too large)
BIN   toolkit/timestep_weighing/flex_timestep_weights_plot.png  Normal file (binary file not shown; after: 190 KiB)