Added base setup for training t2i adapters. Currently untested, saw something else shiny i wanted to finish sirst. Added content_or_style to the training config. It defaults to balanced, which is standard uniform time step sampling. If style or content is passed, it will use cubic sampling for timesteps to favor timesteps that are beneficial for training them. for style, favor later timesteps. For content, favor earlier timesteps.

This commit is contained in:
Jaret Burkett
2023-09-16 08:30:38 -06:00
parent 17e4fe40d7
commit 27f343fc08
8 changed files with 314 additions and 84 deletions

View File

@@ -5,6 +5,7 @@ from collections import OrderedDict
import os
from typing import Union
from diffusers import T2IAdapter
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
@@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -34,7 +36,7 @@ import gc
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig
def flush():
@@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
# t2i adapter
self.adapter_config = None
adapter_raw = self.get_conf('adapter', None)
if adapter_raw is not None:
self.adapter_config = AdapterConfig(**adapter_raw)
# sdxl adapters end in _xl. Only full_adapter_xl for now
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
self.adapter_config.adapter_type += '_xl'
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None:
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
@@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
self.embedding: Union[Embedding, None] = None
# get the device state preset based on what we are training
@@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=self.adapter_config is not None,
train_embedding=self.embed_config is not None,
)
@@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
# replace extension
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
elif self.adapter is not None:
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
@@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
return None
def load_training_state_from_metadata(self, path):
meta = load_metadata_from_safetensors(path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
def load_weights(self, path):
if self.network is not None:
extra_weights = self.network.load_weights(path)
meta = load_metadata_from_safetensors(path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.load_training_state_from_metadata(path)
return extra_weights
else:
print("load_weights not implemented for non-network models")
return None
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
timesteps = timesteps.to(self.device_torch, )
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
@@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
flush()
batch_size = latents.shape[0]
@@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
if self.train_config.use_progressive_denoising:
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.min_denoising_steps,
max_out=self.train_config.max_denoising_steps
))
elif self.train_config.use_linear_denoising:
# starts at max steps and walks down to min steps
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.max_denoising_steps - 1,
max_out=self.train_config.min_denoising_steps
))
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
min_timestep = self.train_config.min_denoising_steps
# todo improve this, but is skews odds for higher timesteps
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
min_timestep = int(min_timestep)
timesteps = torch.randint(
min_timestep,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# get noise
noise = self.sd.get_latent_noise(
@@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
BaseTrainProcess.run(self)
@@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
# set trainable params
params = self.embedding.get_trainable_params()
flush()
elif self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
# t2i adapter
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device_torch,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
params = self.get_params()
if not params:
# set trainable params
params = self.adapter.parameters()
self.sd.adapter = self.adapter
flush()
else:
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)