mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-02 09:09:48 +00:00
Added base setup for training t2i adapters. Currently untested, saw something else shiny i wanted to finish sirst. Added content_or_style to the training config. It defaults to balanced, which is standard uniform time step sampling. If style or content is passed, it will use cubic sampling for timesteps to favor timesteps that are beneficial for training them. for style, favor later timesteps. For content, favor earlier timesteps.
This commit is contained in:
@@ -1,17 +1,27 @@
|
||||
import os.path
|
||||
from collections import OrderedDict
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
transforms.PILToTensor(),
|
||||
])
|
||||
|
||||
|
||||
class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
@@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
adapter_folder_path = self.adapter_config.image_dir
|
||||
adapter_images = []
|
||||
# loop through images
|
||||
for file_item in batch.file_items:
|
||||
img_path = file_item.path
|
||||
file_name_no_ext = os.path.basename(img_path).split('.')[0]
|
||||
# find the image
|
||||
for ext in img_ext_list:
|
||||
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
|
||||
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
|
||||
break
|
||||
|
||||
adapter_tensors = []
|
||||
# load images with torch transforms
|
||||
for adapter_image in adapter_images:
|
||||
img = Image.open(adapter_image)
|
||||
img = adapter_transforms(img)
|
||||
adapter_tensors.append(img)
|
||||
|
||||
# stack them
|
||||
adapter_tensors = torch.stack(adapter_tensors)
|
||||
return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
# todo move this to data loader
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
# flush()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if self.adapter:
|
||||
down_block_additional_residuals = self.adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
# flush()
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
if self.adapter:
|
||||
# todo, diffusers does this on t2i training, is it better approach?
|
||||
# Denoise the latents
|
||||
denoised_latents = noise_pred * (-sigmas) + noisy_latents
|
||||
weighing = sigmas ** -2.0
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = batch.latents # we are computing loss against denoise latents
|
||||
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
|
||||
# MSE loss
|
||||
loss = torch.mean(
|
||||
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
target = noise
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
# TODO: I think the sigma method does not need this. Check
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from diffusers import T2IAdapter
|
||||
# from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
@@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
from toolkit.progress_bar import ToolkitProgressBar
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
@@ -34,7 +36,7 @@ import gc
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
# t2i adapter
|
||||
self.adapter_config = None
|
||||
adapter_raw = self.get_conf('adapter', None)
|
||||
if adapter_raw is not None:
|
||||
self.adapter_config = AdapterConfig(**adapter_raw)
|
||||
# sdxl adapters end in _xl. Only full_adapter_xl for now
|
||||
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
|
||||
self.adapter_config.adapter_type += '_xl'
|
||||
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.embed_config is None and self.network_config is None:
|
||||
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
@@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
# get the device state preset based on what we are training
|
||||
@@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=self.network_config is not None,
|
||||
train_adapter=self.adapter_config is not None,
|
||||
train_embedding=self.embed_config is not None,
|
||||
)
|
||||
|
||||
@@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# replace extension
|
||||
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
|
||||
self.embedding.save(emb_file_path)
|
||||
elif self.adapter is not None:
|
||||
# save adapter
|
||||
state_dict = self.adapter.state_dict()
|
||||
save_t2i_from_diffusers(
|
||||
state_dict,
|
||||
output_file=file_path,
|
||||
meta=save_meta,
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
return None
|
||||
|
||||
def load_training_state_from_metadata(self, path):
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
def load_weights(self, path):
|
||||
if self.network is not None:
|
||||
extra_weights = self.network.load_weights(path)
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
self.load_training_state_from_metadata(path)
|
||||
return extra_weights
|
||||
else:
|
||||
print("load_weights not implemented for non-network models")
|
||||
return None
|
||||
|
||||
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
|
||||
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
|
||||
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
|
||||
timesteps = timesteps.to(self.device_torch, )
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
@@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
flush()
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
@@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
if self.train_config.use_progressive_denoising:
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.min_denoising_steps,
|
||||
max_out=self.train_config.max_denoising_steps
|
||||
))
|
||||
|
||||
elif self.train_config.use_linear_denoising:
|
||||
# starts at max steps and walks down to min steps
|
||||
min_timestep = int(value_map(
|
||||
self.step_num,
|
||||
min_in=0,
|
||||
max_in=self.train_config.max_denoising_steps,
|
||||
min_out=self.train_config.max_denoising_steps - 1,
|
||||
max_out=self.train_config.min_denoising_steps
|
||||
))
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
min_timestep = self.train_config.min_denoising_steps
|
||||
|
||||
# todo improve this, but is skews odds for higher timesteps
|
||||
# 50% chance to use midpoint as the min_time_step
|
||||
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||
if torch.rand(1) > 0.5:
|
||||
min_timestep = mid_point
|
||||
|
||||
# 50% chance to use midpoint as the min_time_step
|
||||
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
|
||||
if torch.rand(1) > 0.5:
|
||||
min_timestep = mid_point
|
||||
|
||||
min_timestep = int(min_timestep)
|
||||
|
||||
timesteps = torch.randint(
|
||||
min_timestep,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
@@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
def run(self):
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
# run base process run
|
||||
BaseTrainProcess.run(self)
|
||||
|
||||
@@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
flush()
|
||||
elif self.adapter_config is not None:
|
||||
self.adapter = T2IAdapter(
|
||||
in_channels=self.adapter_config.in_channels,
|
||||
channels=self.adapter_config.channels,
|
||||
num_res_blocks=self.adapter_config.num_res_blocks,
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
# t2i adapter
|
||||
latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
|
||||
if latest_save_path is not None:
|
||||
# load adapter from path
|
||||
print(f"Loading adapter from {latest_save_path}")
|
||||
loaded_state_dict = load_t2i_model(
|
||||
latest_save_path,
|
||||
self.device_torch,
|
||||
dtype=dtype
|
||||
)
|
||||
self.adapter.load_state_dict(loaded_state_dict)
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
# set trainable params
|
||||
params = self.adapter.parameters()
|
||||
self.sd.adapter = self.adapter
|
||||
flush()
|
||||
else:
|
||||
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
safetensors
|
||||
git+https://github.com/huggingface/diffusers.git
|
||||
diffusers==0.21.1
|
||||
transformers
|
||||
lycoris_lora
|
||||
flatten_json
|
||||
|
||||
@@ -39,6 +39,7 @@ class SampleConfig:
|
||||
|
||||
NetworkType = Literal['lora', 'locon']
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: NetworkType = kwargs.get('type', 'lora')
|
||||
@@ -58,6 +59,17 @@ class NetworkConfig:
|
||||
self.dropout: Union[float, None] = kwargs.get('dropout', None)
|
||||
|
||||
|
||||
class AdapterConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.in_channels: int = kwargs.get('in_channels', 3)
|
||||
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
|
||||
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
|
||||
self.downscale_factor: int = kwargs.get('downscale_factor', 16)
|
||||
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
|
||||
self.image_dir: str = kwargs.get('image_dir', None)
|
||||
self.test_img_path: str = kwargs.get('test_img_path', None)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.trigger = kwargs.get('trigger', 'custom_embedding')
|
||||
@@ -66,9 +78,13 @@ class EmbeddingConfig:
|
||||
self.save_format = kwargs.get('save_format', 'safetensors')
|
||||
|
||||
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||
self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
|
||||
self.steps: int = kwargs.get('steps', 1000)
|
||||
self.lr = kwargs.get('lr', 1e-6)
|
||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||
@@ -80,8 +96,6 @@ class TrainConfig:
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
|
||||
self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
@@ -255,6 +269,7 @@ class GenerateImageConfig:
|
||||
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
|
||||
output_tail: str = '', # tail to add to output filename
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
adapter_image_path: str = None, # path to adapter image
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -277,6 +292,7 @@ class GenerateImageConfig:
|
||||
self.add_prompt_file: bool = add_prompt_file
|
||||
self.output_tail: str = output_tail
|
||||
self.gen_time: int = int(time.time() * 1000)
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
|
||||
@@ -341,13 +341,7 @@ class ToolkitNetworkMixin:
|
||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
|
||||
elif isinstance(multiplier, list):
|
||||
tensor_list = []
|
||||
for m in multiplier:
|
||||
if isinstance(m, int) or isinstance(m, float):
|
||||
tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
|
||||
elif isinstance(m, torch.Tensor):
|
||||
tensor_list.append(m.clone().detach().to(device, dtype=dtype))
|
||||
tensor_multiplier = torch.cat(tensor_list)
|
||||
tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
|
||||
elif isinstance(multiplier, torch.Tensor):
|
||||
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
|
||||
|
||||
|
||||
@@ -161,10 +161,35 @@ def save_lora_from_diffusers(
|
||||
else:
|
||||
converted_key = key
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def save_t2i_from_diffusers(
|
||||
t2i_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
dtype=get_torch_dtype('fp16'),
|
||||
):
|
||||
# todo: test compatibility with non diffusers
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in t2i_state_dict.items():
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=dtype)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta
|
||||
)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def load_t2i_model(
|
||||
path_to_file,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in raw_state_dict.items():
|
||||
# todo see if we need to convert dict
|
||||
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
|
||||
return converted_state_dict
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import copy
|
||||
|
||||
@@ -15,16 +17,22 @@ empty_preset = {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
}
|
||||
},
|
||||
'adapter': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_train_sd_device_state_preset (
|
||||
device: torch.DeviceObjType,
|
||||
def get_train_sd_device_state_preset(
|
||||
device: Union[str, torch.device],
|
||||
train_unet: bool = False,
|
||||
train_text_encoder: bool = False,
|
||||
cached_latents: bool = False,
|
||||
train_lora: bool = False,
|
||||
train_adapter: bool = False,
|
||||
train_embedding: bool = False,
|
||||
):
|
||||
preset = copy.deepcopy(empty_preset)
|
||||
@@ -51,9 +59,14 @@ def get_train_sd_device_state_preset (
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['unet']['training'] = True
|
||||
|
||||
|
||||
if train_lora:
|
||||
preset['text_encoder']['requires_grad'] = False
|
||||
preset['unet']['requires_grad'] = False
|
||||
|
||||
if train_adapter:
|
||||
preset['adapter']['requires_grad'] = True
|
||||
preset['adapter']['training'] = True
|
||||
preset['adapter']['device'] = device
|
||||
preset['unet']['training'] = True
|
||||
|
||||
return preset
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import shutil
|
||||
@@ -7,6 +8,7 @@ import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from PIL import Image
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.nn import Parameter
|
||||
@@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.sd_device_states_presets import empty_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
||||
import diffusers
|
||||
|
||||
# tell it to shut up
|
||||
@@ -110,7 +114,7 @@ class StableDiffusion:
|
||||
self.unet: Union[None, 'UNet2DConditionModel']
|
||||
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
|
||||
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
||||
self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler
|
||||
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
||||
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
@@ -119,6 +123,7 @@ class StableDiffusion:
|
||||
|
||||
# to hold network if there is one
|
||||
self.network = None
|
||||
self.adapter: Union['T2IAdapter', None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
|
||||
@@ -291,8 +296,18 @@ class StableDiffusion:
|
||||
if sampler.startswith("sample_") and self.is_xl:
|
||||
# using kdiffusion
|
||||
Pipe = StableDiffusionKDiffusionXLPipeline
|
||||
else:
|
||||
elif self.is_xl:
|
||||
Pipe = StableDiffusionXLPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionPipeline
|
||||
|
||||
extra_args = {}
|
||||
if self.adapter:
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLAdapterPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
@@ -305,11 +320,12 @@ class StableDiffusion:
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
add_watermarker=False,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
# force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
|
||||
pipeline.watermark = None
|
||||
else:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
@@ -318,6 +334,7 @@ class StableDiffusion:
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
# disable progress bar
|
||||
@@ -340,6 +357,12 @@ class StableDiffusion:
|
||||
for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
|
||||
gen_config = image_configs[i]
|
||||
|
||||
extra = {}
|
||||
if gen_config.adapter_image_path is not None:
|
||||
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
extra['image'] = validation_image
|
||||
|
||||
if self.network is not None:
|
||||
self.network.multiplier = gen_config.network_multiplier
|
||||
torch.manual_seed(gen_config.seed)
|
||||
@@ -355,7 +378,6 @@ class StableDiffusion:
|
||||
grs = 0.7
|
||||
# grs = 0.0
|
||||
|
||||
extra = {}
|
||||
if sampler.startswith("sample_"):
|
||||
extra['use_karras_sigmas'] = True
|
||||
|
||||
@@ -379,6 +401,7 @@ class StableDiffusion:
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
**extra
|
||||
).images[0]
|
||||
|
||||
gen_config.save_image(img)
|
||||
@@ -517,6 +540,7 @@ class StableDiffusion:
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
@@ -558,6 +582,7 @@ class StableDiffusion:
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
@@ -855,6 +880,7 @@ class StableDiffusion:
|
||||
# saves the current device state for all modules
|
||||
# this is useful for when we want to alter the state and restore it
|
||||
self.device_state = {
|
||||
**empty_preset,
|
||||
'vae': {
|
||||
'training': self.vae.training,
|
||||
'device': self.vae.device,
|
||||
@@ -880,6 +906,12 @@ class StableDiffusion:
|
||||
'device': self.text_encoder.device,
|
||||
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
|
||||
}
|
||||
if self.adapter is not None:
|
||||
self.device_state['adapter'] = {
|
||||
'training': self.adapter.training,
|
||||
'device': self.adapter.device,
|
||||
'requires_grad': self.adapter.requires_grad,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
# restores the device state for all modules
|
||||
@@ -927,6 +959,14 @@ class StableDiffusion:
|
||||
self.text_encoder.eval()
|
||||
self.text_encoder.to(state['text_encoder']['device'])
|
||||
self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
|
||||
|
||||
if self.adapter is not None:
|
||||
self.adapter.to(state['adapter']['device'])
|
||||
self.adapter.requires_grad_(state['adapter']['requires_grad'])
|
||||
if state['adapter']['training']:
|
||||
self.adapter.train()
|
||||
else:
|
||||
self.adapter.eval()
|
||||
flush()
|
||||
|
||||
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
|
||||
@@ -940,9 +980,9 @@ class StableDiffusion:
|
||||
if device_state_preset in ['cache_latents']:
|
||||
active_modules = ['vae']
|
||||
if device_state_preset in ['generate']:
|
||||
active_modules = ['vae', 'unet', 'text_encoder']
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
|
||||
|
||||
state = {}
|
||||
state = copy.deepcopy(empty_preset)
|
||||
# vae
|
||||
state['vae'] = {
|
||||
'training': 'vae' in training_modules,
|
||||
@@ -973,4 +1013,11 @@ class StableDiffusion:
|
||||
'requires_grad': 'text_encoder' in training_modules,
|
||||
}
|
||||
|
||||
if self.adapter is not None:
|
||||
state['adapter'] = {
|
||||
'training': 'adapter' in training_modules,
|
||||
'device': self.device_torch if 'adapter' in active_modules else 'cpu',
|
||||
'requires_grad': 'adapter' in training_modules,
|
||||
}
|
||||
|
||||
self.set_device_state(state)
|
||||
|
||||
Reference in New Issue
Block a user