mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added base setup for training t2i adapters. Currently untested, saw something else shiny i wanted to finish sirst. Added content_or_style to the training config. It defaults to balanced, which is standard uniform time step sampling. If style or content is passed, it will use cubic sampling for timesteps to favor timesteps that are beneficial for training them. for style, favor later timesteps. For content, favor earlier timesteps.
This commit is contained in:
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from diffusers import T2IAdapter
|
||||
# from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
@@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
@@ -34,7 +36,7 @@ import gc
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
# t2i adapter
|
||||
self.adapter_config = None
|
||||
adapter_raw = self.get_conf('adapter', None)
|
||||
if adapter_raw is not None:
|
||||
self.adapter_config = AdapterConfig(**adapter_raw)
|
||||
# sdxl adapters end in _xl. Only full_adapter_xl for now
|
||||
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
|
||||
self.adapter_config.adapter_type += '_xl'
|
||||
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.embed_config is None and self.network_config is None:
|
||||
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
@@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
# get the device state preset based on what we are training
|
||||
@@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=self.network_config is not None,
|
||||
train_adapter=self.adapter_config is not None,
|
||||
train_embedding=self.embed_config is not None,
|
||||
)
|
||||
|
||||
@@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# replace extension
|
||||
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
|
||||
self.embedding.save(emb_file_path)
|
||||
elif self.adapter is not None:
|
||||
# save adapter
|
||||
state_dict = self.adapter.state_dict()
|
||||
save_t2i_from_diffusers(
|
||||
state_dict,
|
||||
output_file=file_path,
|
||||
meta=save_meta,
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
return None
|
||||
|
||||
def load_training_state_from_metadata(self, path):
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
def load_weights(self, path):
|
||||
if self.network is not None:
|
||||
extra_weights = self.network.load_weights(path)
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
self.load_training_state_from_metadata(path)
|
||||
return extra_weights
|
||||
else:
|
||||
print("load_weights not implemented for non-network models")
|
||||
return None
|
||||
|
||||
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
|
||||
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
|
||||
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
|
||||
timesteps = timesteps.to(self.device_torch, )
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
@@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
flush()
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
@@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
if self.train_config.use_progressive_denoising:
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.min_denoising_steps,
|
||||
max_out=self.train_config.max_denoising_steps
|
||||
))
|
||||
|
||||
elif self.train_config.use_linear_denoising:
|
||||
# starts at max steps and walks down to min steps
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.max_denoising_steps - 1,
|
||||
max_out=self.train_config.min_denoising_steps
|
||||
))
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
min_timestep = self.train_config.min_denoising_steps
|
||||
|
||||
# todo improve this, but is skews odds for higher timesteps
|
||||
# 50% chance to use midpoint as the min_time_step
|
||||
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||
if torch.rand(1) > 0.5:
|
||||
min_timestep = mid_point
|
||||
|
||||
# 50% chance to use midpoint as the min_time_step
|
||||
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||
if torch.rand(1) > 0.5:
|
||||
min_timestep = mid_point
|
||||
|
||||
min_timestep = int(min_timestep)
|
||||
|
||||
timesteps = torch.randint(
|
||||
min_timestep,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
@@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def run(self):
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
# run base process run
|
||||
BaseTrainProcess.run(self)
|
||||
|
||||
@@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
flush()
|
||||
elif self.adapter_config is not None:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
# t2i adapter
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
loaded_state_dict = load_t2i_model(
|
||||
latest_save_path,
|
||||
self.device_torch,
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
# set trainable params
|
||||
params = self.adapter.parameters()
|
||||
self.sd.adapter = self.adapter
|
||||
flush()
|
||||
else:
|
||||
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user