Massive speed increase. Added latent caching both to disk and to memory
@@ -22,19 +22,21 @@ class SDTrainer(BaseSDTrainProcess):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)

        # textual inversion
        # if self.embedding is not None:
        # set text encoder to train. Not sure if this is necessary but diffusers example did it
        # self.sd.text_encoder.train()
        # move vae to device if we did not cache latents
        if not self.is_latents_cached:
            self.sd.vae.eval()
            self.sd.vae.to(self.device_torch)
        else:
            # offload it. Already cached
            self.sd.vae.to('cpu')

    def hook_train_loop(self, batch):

        dtype = get_torch_dtype(self.train_config.dtype)
        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
        network_weight_list = batch.get_network_weight_list()
        flush()
        # flush()
        self.optimizer.zero_grad()

        # text encoding
        grad_on_text_encoder = False
@@ -57,9 +59,9 @@ class SDTrainer(BaseSDTrainProcess):
        with network:
            with torch.set_grad_enabled(grad_on_text_encoder):
                conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
            # if not grad_on_text_encoder:
            #     # detach the embeddings
            #     conditional_embeds = conditional_embeds.detach()
            if not grad_on_text_encoder:
                # detach the embeddings
                conditional_embeds = conditional_embeds.detach()
            # flush()

            noise_pred = self.sd.predict_noise(
@@ -68,7 +70,7 @@ class SDTrainer(BaseSDTrainProcess):
                timestep=timesteps,
                guidance_scale=1.0,
            )
            flush()
            # flush()
            # 9.18 gb
            noise = noise.to(self.device_torch, dtype=dtype).detach()

@@ -95,11 +97,10 @@ class SDTrainer(BaseSDTrainProcess):
            # I spent weeks on fighting this. DON'T DO IT
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
            flush()
            # flush()

            # apply gradients
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.lr_scheduler.step()

            if self.embedding is not None:

@@ -16,6 +16,7 @@ from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler

from toolkit.scheduler import get_lr_scheduler
@@ -73,6 +74,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.data_loader_reg: Union[DataLoader, None] = None
        self.trigger_word = self.get_conf('trigger_word', None)

        # store is all are cached. Allows us to not load vae if we don't need to
        self.is_latents_cached = True
        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
@@ -82,6 +85,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                if not is_caching:
                    self.is_latents_cached = False
                if dataset.is_reg:
                    if self.datasets_reg is None:
                        self.datasets_reg = []
@@ -355,9 +361,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        print("load_weights not implemented for non-network models")
        return None

    def process_general_training_batch(self, batch):
    def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
        with torch.no_grad():
            imgs = batch.tensor
            prompts = batch.get_caption_list()
            is_reg_list = batch.get_is_reg_list()

@@ -382,11 +387,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
                )
                conditioned_prompts.append(prompt)

            batch_size = imgs.shape[0]

            dtype = get_torch_dtype(self.train_config.dtype)
            imgs = imgs.to(self.device_torch, dtype=dtype)
            latents = self.sd.encode_images(imgs)
            imgs = None
            if batch.tensor is not None:
                imgs = batch.tensor
                imgs = imgs.to(self.device_torch, dtype=dtype)
            if batch.latents is not None:
                latents = batch.latents.to(self.device_torch, dtype=dtype)
            else:
                latents = self.sd.encode_images(imgs)
            flush()

            batch_size = latents.shape[0]

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
@@ -397,8 +409,8 @@ class BaseSDTrainProcess(BaseTrainProcess):

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=imgs.shape[2],
                pixel_width=imgs.shape[3],
                height=latents.shape[2],
                width=latents.shape[3],
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset
            ).to(self.device_torch, dtype=dtype)
@@ -416,23 +428,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def run(self):
        # run base process run
        BaseTrainProcess.run(self)
        ### HOOk ###
        self.before_dataset_load()
        # load datasets if passed in the root process
        if self.datasets is not None:
            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
        if self.datasets_reg is not None:
            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size)

        ### HOOK ###
        self.hook_before_model_load()
        # run base sd process run
        self.sd.load_model()

        if self.train_config.gradient_checkpointing:
            # may get disabled elsewhere
            self.sd.unet.enable_gradient_checkpointing()

        dtype = get_torch_dtype(self.train_config.dtype)

        # model is loaded from BaseSDProcess
@@ -480,6 +481,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
        vae.eval()
        flush()

        ### HOOk ###
        self.before_dataset_load()
        # load datasets if passed in the root process
        if self.datasets is not None:
            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
        if self.datasets_reg is not None:
            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)

        if self.network_config is not None:
            # TODO should we completely switch to LycorisSpecialNetwork?

@@ -667,13 +676,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.print("Generating baseline samples before training")
            self.sample(0)

        self.progress_bar = tqdm(
        self.progress_bar = ToolkitProgressBar(
            total=self.train_config.steps,
            desc=self.job.name,
            leave=True,
            initial=self.step_num,
            iterable=range(0, self.train_config.steps),
        )
        self.progress_bar.pause()

        if self.data_loader is not None:
            dataloader = self.data_loader
@@ -691,12 +701,30 @@ class BaseSDTrainProcess(BaseTrainProcess):

        # zero any gradients
        optimizer.zero_grad()
        flush()


        self.lr_scheduler.step(self.step_num)

        if self.embedding is not None or self.train_config.train_text_encoder:
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.train()
            else:
                self.sd.text_encoder.train()
        else:
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.eval()
            else:
                self.sd.text_encoder.eval()
        if self.train_config.train_unet or self.embedding:
            self.sd.unet.train()
        else:
            self.sd.unet.eval()
        flush()
        # self.step_num = 0
        for step in range(self.step_num, self.train_config.steps):
            self.progress_bar.unpause()
            with torch.no_grad():
                # if is even step and we have a reg dataset, use that
                # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
@@ -725,21 +753,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                # turn on normalization if we are using it and it is not on
                if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
                    self.network.is_normalizing = True
                flush()
                if self.embedding is not None or self.train_config.train_text_encoder:
                    if isinstance(self.sd.text_encoder, list):
                        for te in self.sd.text_encoder:
                            te.train()
                    else:
                        self.sd.text_encoder.train()

                self.sd.unet.train()
                # flush()
            ### HOOK ###
            loss_dict = self.hook_train_loop(batch)
            flush()
            # flush()
            # setup the networks to gradient checkpointing and everything works

            with torch.no_grad():
                torch.cuda.empty_cache()
                if self.train_config.optimizer.lower().startswith('dadaptation') or \
                        self.train_config.optimizer.lower().startswith('prodigy'):
                    learning_rate = (
@@ -757,24 +778,27 @@ class BaseSDTrainProcess(BaseTrainProcess):

                # don't do on first step
                if self.step_num != self.start_step:
                    # pause progress bar
                    self.progress_bar.unpause()  # makes it so doesn't track time
                    if is_sample_step:
                        self.progress_bar.pause()
                        # print above the progress bar
                        self.sample(self.step_num)
                        self.progress_bar.unpause()

                    if is_save_step:
                        # print above the progress bar
                        self.progress_bar.pause()
                        self.print(f"Saving at step {self.step_num}")
                        self.save(self.step_num)
                        self.progress_bar.unpause()

                    if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
                        self.progress_bar.pause()
                        # log to tensorboard
                        if self.writer is not None:
                            for key, value in loss_dict.items():
                                self.writer.add_scalar(f"{key}", value, self.step_num)
                            self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                        self.progress_bar.refresh()
                        self.progress_bar.unpause()

                # sets progress bar to match out step
                self.progress_bar.update(step - self.progress_bar.n)
@@ -789,6 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if isinstance(batch, DataLoaderBatchDTO):
                    batch.cleanup()

        self.progress_bar.close()
        self.sample(self.step_num + 1)
        print("")
        self.save()

@@ -1,4 +1,13 @@
import gc

import torch


def value_map(inputs, min_in, max_in, min_out, max_out):
    return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out


def flush(garbage_collect=True):
    torch.cuda.empty_cache()
    if garbage_collect:
        gc.collect()

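The new garbage_collect flag matters for the caching loop added elsewhere in this commit: running gc.collect() after every cached item is slow, so the hot path can empty only the CUDA cache and leave a full collection for the end. A small illustrative sketch of that pattern (not part of the commit):

from toolkit.basic import flush

# inside a hot per-item loop: skip the Python GC pass, just empty the CUDA cache
for _ in range(100):
    flush(garbage_collect=False)

# once the loop is done, the default garbage_collect=True also runs gc.collect()
flush()
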
@@ -197,6 +197,11 @@ class DatasetConfig:
        self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
        self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))

        # cache latents will store them in memory
        self.cache_latents: bool = kwargs.get('cache_latents', False)
        # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
        self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)

        # legacy compatability
        legacy_caption_type = kwargs.get('caption_type', None)
        if legacy_caption_type:

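For context, a dataset entry that opts into the new caching could look like the sketch below. Only cache_latents, cache_latents_to_disk, folder_path and caption_ext are confirmed by this diff; the path value is a placeholder and the call as a whole is a hypothetical illustration:

from toolkit.config_modules import DatasetConfig

# hypothetical dataset entry: keep latents in RAM and also persist them to disk
example_dataset = DatasetConfig(
    folder_path='/path/to/images',   # placeholder path
    caption_ext='txt',
    cache_latents=True,              # hold encoded latents in memory between epochs
    cache_latents_to_disk=True,      # also write them to <folder>/_latent_cache/*.safetensors
)
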
@@ -1,11 +1,10 @@
import json
import os
import random
from typing import List
from typing import List, TYPE_CHECKING

import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
@@ -13,11 +12,13 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A

from toolkit import image_utils
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


class ImageDataset(Dataset, CaptionMixin):
    def __init__(self, config):
@@ -288,9 +289,14 @@ class PairedImageDataset(Dataset):
        return img, prompt, (self.neg_weight, self.pos_weight)


class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):

    def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
    def __init__(
            self,
            dataset_config: 'DatasetConfig',
            batch_size=1,
            sd: 'StableDiffusion' = None,
    ):
        super().__init__()
        self.dataset_config = dataset_config
        folder_path = dataset_config.folder_path
@@ -298,6 +304,15 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
        if self.dataset_path is None:
            self.dataset_path = folder_path

        self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
        self.is_caching_latents_to_memory = dataset_config.cache_latents
        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk

        self.sd = sd

        if self.sd is None and self.is_caching_latents:
            raise ValueError(f"sd is required for caching latents")

        self.caption_type = dataset_config.caption_ext
        self.default_caption = dataset_config.default_caption
        self.random_scale = dataset_config.random_scale
@@ -344,19 +359,21 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
        # print(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

        self.setup_epoch()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

        self.setup_epoch()

    def setup_epoch(self):
        # TODO: set this up to redo cropping and everything else
        # do not call for now
        if self.dataset_config.buckets:
            # setup buckets
            self.setup_buckets()
        if self.is_caching_latents:
            self.cache_latents_all_latents()

    def __len__(self):
        if self.dataset_config.buckets:
@@ -381,7 +398,11 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
        return self._get_single_item(item)


def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
def get_dataloader_from_datasets(
        dataset_options,
        batch_size=1,
        sd: 'StableDiffusion' = None,
) -> DataLoader:
    if dataset_options is None or len(dataset_options) == 0:
        return None

@@ -402,7 +423,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
    for config in dataset_config_list:

        if config.type == 'image':
            dataset = AiToolkitDataset(config, batch_size=batch_size)
            dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
            datasets.append(dataset)
            if config.buckets:
                has_buckets = True
@@ -432,14 +453,14 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
            drop_last=False,
            shuffle=True,
            collate_fn=dto_collation,  # Use the custom collate function
            num_workers=2
            num_workers=1
        )
    else:
        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            num_workers=1,
            collate_fn=dto_collation
        )
    return data_loader

@@ -6,7 +6,7 @@ from PIL import Image
from PIL.ImageOps import exif_transpose

from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
@@ -21,8 +21,9 @@ def print_once(msg):
    printed_messages.append(msg)


class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
    def __init__(self, **kwargs):
        super().__init__()
        self.path = kwargs.get('path', None)
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # process width and height
@@ -53,12 +54,22 @@ class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):

    def cleanup(self):
        self.tensor = None
        self.cleanup_latent()


class DataLoaderBatchDTO:
    def __init__(self, **kwargs):
        self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
        self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
        is_latents_cached = self.file_items[0].is_latent_cached
        self.tensor: Union[torch.Tensor, None] = None
        self.latents: Union[torch.Tensor, None] = None
        if not is_latents_cached:
            # only return a tensor if latents are not cached
            self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
        # if we have encoded latents, we concatenate them
        self.latents: Union[torch.Tensor, None] = None
        if is_latents_cached:
            self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])

    def get_is_reg_list(self):
        return [x.is_reg for x in self.file_items]
@@ -82,3 +93,4 @@ class DataLoaderBatchDTO:
        self.tensor = None
        for file_item in self.file_items:
            file_item.cleanup()
        del self.tensor

@@ -1,14 +1,26 @@
import base64
import hashlib
import json
import math
import os
import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from toolkit.basic import flush
from toolkit.buckets import get_bucket_for_image_size
from toolkit.metadata import get_meta_for_safetensors
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose

from toolkit.train_tools import get_torch_dtype

if TYPE_CHECKING:
    from toolkit.data_loader import AiToolkitDataset
    from toolkit.data_transfer_object.data_loader import FileItemDTO
@@ -219,7 +231,9 @@ class ImageProcessingDTOMixin:
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose]
    ):
        # todo make sure this matches
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
        try:
            img = Image.open(self.path).convert('RGB')
            img = exif_transpose(img)
@@ -265,3 +279,139 @@ class ImageProcessingDTOMixin:
            img = transform(img)

        self.tensor = img

class LatentCachingFileItemDTOMixin:
    def __init__(self):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__()
        self._encoded_latent: Union[torch.Tensor, None] = None
        self._latent_path: Union[str, None] = None
        self.is_latent_cached = False
        self.is_caching_to_disk = False
        self.is_caching_to_memory = False
        self.latent_load_device = 'cpu'
        # sd1 or sdxl or others
        self.latent_space_version = 'sd1'
        # todo, increment this if we change the latent format to invalidate cache
        self.latent_version = 1

    def get_latent_info_dict(self: 'FileItemDTO'):
        return OrderedDict([
            ("filename", os.path.basename(self.path)),
            ("scale_to_width", self.scale_to_width),
            ("scale_to_height", self.scale_to_height),
            ("crop_x", self.crop_x),
            ("crop_y", self.crop_y),
            ("crop_width", self.crop_width),
            ("crop_height", self.crop_height),
            ("latent_space_version", self.latent_space_version),
            ("latent_version", self.latent_version),
        ])

    def get_latent_path(self: 'FileItemDTO', recalculate=False):
        if self._latent_path is not None and not recalculate:
            return self._latent_path
        else:
            # we store latents in a folder in same path as image called _latent_cache
            img_dir = os.path.dirname(self.path)
            latent_dir = os.path.join(img_dir, '_latent_cache')
            hash_dict = self.get_latent_info_dict()
            filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
            # get base64 hash of md5 checksum of hash_dict
            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
            hash_str = hash_str.replace('=', '')
            self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')

        return self._latent_path

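As a rough illustration of the naming scheme above (not part of the commit), the cache file name is the image's base name plus a short URL-safe hash of the info dict, so changing the crop, scale, or latent version yields a different cache file. Standalone sketch with hypothetical values:

import base64, hashlib, json, os
from collections import OrderedDict

# hypothetical values standing in for a FileItemDTO's get_latent_info_dict()
info = OrderedDict([
    ("filename", "dog.jpg"), ("scale_to_width", 512), ("scale_to_height", 512),
    ("crop_x", 0), ("crop_y", 0), ("crop_width", 512), ("crop_height", 512),
    ("latent_space_version", "sd1"), ("latent_version", 1),
])
hash_input = json.dumps(info, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
print(os.path.join('_latent_cache', f"dog_{hash_str}.safetensors"))
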
    def cleanup_latent(self):
        if self._encoded_latent is not None:
            if not self.is_caching_to_memory:
                # we are caching on disk, don't save in memory
                self._encoded_latent = None
            else:
                # move it back to cpu
                self._encoded_latent = self._encoded_latent.to('cpu')

    def get_latent(self, device=None):
        if not self.is_latent_cached:
            return None
        if self._encoded_latent is None:
            # load it from disk
            state_dict = load_file(
                self.get_latent_path(),
                device=device if device is not None else self.latent_load_device
            )
            self._encoded_latent = state_dict['latent']
        return self._encoded_latent

class LatentCachingMixin:
    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.latent_cache = {}

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
        print(f"Caching latents for {self.dataset_path}")
        # cache all latents to disk
        to_disk = self.is_caching_latents_to_disk
        to_memory = self.is_caching_latents_to_memory

        if to_disk:
            print(" - Saving latents to disk")
        if to_memory:
            print(" - Keeping latents in memory")
        # move sd items to cpu except for vae
        self.sd.set_device_state_preset('cache_latents')

        # use tqdm to show progress
        for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
            # set latent space version
            if self.sd.is_xl:
                file_item.latent_space_version = 'sdxl'
            else:
                file_item.latent_space_version = 'sd1'
            file_item.is_caching_to_disk = to_disk
            file_item.is_caching_to_memory = to_memory
            file_item.latent_load_device = self.sd.device

            latent_path = file_item.get_latent_path(recalculate=True)
            # check if it is saved to disk already
            if os.path.exists(latent_path):
                if to_memory:
                    # load it into memory
                    state_dict = load_file(latent_path, device='cpu')
                    file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
            else:
                # not saved to disk, calculate
                # load the image first
                file_item.load_and_process_image(self.transform)
                dtype = self.sd.torch_dtype
                device = self.sd.device_torch
                # add batch dimension
                imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                latent = self.sd.encode_images(imgs).squeeze(0)
                # save_latent
                if to_disk:
                    state_dict = OrderedDict([
                        ('latent', latent.clone().detach().cpu()),
                    ])
                    # metadata
                    meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                    os.makedirs(os.path.dirname(latent_path), exist_ok=True)
                    save_file(state_dict, latent_path, metadata=meta)

                if to_memory:
                    # keep it in memory
                    file_item._encoded_latent = latent.to('cpu', dtype=self.sd.dtype)

            flush(garbage_collect=False)
            file_item.is_latent_cached = True

        # restore device state
        self.sd.restore_device_state()

toolkit/progress_bar.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from tqdm import tqdm
import time


class ToolkitProgressBar(tqdm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.paused = False

    def pause(self):
        if not self.paused:
            self.paused = True
            self.last_time = self._time()

    def unpause(self):
        if self.paused:
            self.paused = False
            self.start_t += self._time() - self.last_time

    def update(self, *args, **kwargs):
        if not self.paused:
            super().update(*args, **kwargs)
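A small usage sketch (not from the commit) of how the pause/unpause pair is meant to be used: pausing before slow, untimed work such as sampling or saving keeps that wall-clock time out of tqdm's rate estimate, and update() becomes a no-op while paused:

import time
# assumes toolkit/progress_bar.py from this commit is importable
from toolkit.progress_bar import ToolkitProgressBar

pbar = ToolkitProgressBar(total=10, desc='demo')
for step in range(10):
    time.sleep(0.01)           # stand-in for a training step
    if step == 5:
        pbar.pause()           # stop the clock before untimed work
        time.sleep(0.5)        # stand-in for sampling / saving
        pbar.unpause()         # resume; the 0.5 s is not counted in the rate
    pbar.update(1)
pbar.close()
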
@@ -495,7 +495,8 @@ def build_latent_image_batch_for_prompt_pair(
|
||||
|
||||
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    if trigger is None:
        return prompt
        # process as empty string to remove any [trigger] tokens
        trigger = ''
    output_prompt = prompt
    default_replacements = ["[name]", "[trigger]"]

@@ -513,15 +514,16 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i
        # replace it
        output_prompt = output_prompt.replace(to_replace, replace_with)

        # see how many times replace_with is in the prompt
        num_instances = output_prompt.count(replace_with)
        if trigger.strip() != "":
            # see how many times replace_with is in the prompt
            num_instances = output_prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt
            if num_instances == 0 and add_if_not_present:
                # add it to the beginning of the prompt
                output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
            if num_instances > 1:
                print(
                    f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

    return output_prompt

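To illustrate my reading of this change (not an authoritative test): the occurrence counting, the prepend-if-missing step, and the duplicate-trigger warning now run only when a non-empty trigger is supplied, while [name]/[trigger] placeholders are still substituted either way. A minimal sketch:

from toolkit.prompt_utils import inject_trigger_into_prompt

# "[trigger]" placeholder is swapped for the trigger word
print(inject_trigger_into_prompt("photo of [trigger] at night", trigger="zkz person"))
# expected: photo of zkz person at night

# with no placeholder present, the trigger should be prepended (add_if_not_present=True)
print(inject_trigger_into_prompt("photo of a cat", trigger="zkz person"))
# expected: zkz person photo of a cat
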
@@ -2,7 +2,7 @@ import gc
import json
import shutil
import typing
from typing import Union, List, Tuple, Iterator
from typing import Union, List, Literal, Iterator
import sys
import os
from collections import OrderedDict
@@ -48,6 +48,8 @@ DO_NOT_TRAIN_WEIGHTS = [
    "unet_time_embedding.linear_2.weight",
]

DeviceStatePreset = Literal['cache_latents']


class BlankNetwork:

@@ -102,6 +104,8 @@ class StableDiffusion:
        self.model_config = model_config
        self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

        self.device_state = None

        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
        self.vae: Union[None, 'AutoencoderKL']
        self.unet: Union[None, 'UNet2DConditionModel']
@@ -128,8 +132,6 @@ class StableDiffusion:
        if self.is_loaded:
            return
        dtype = get_torch_dtype(self.dtype)

        # TODO handle other schedulers
        # sch = KDPM2DiscreteScheduler
        if self.noise_scheduler is None:
            scheduler = get_sampler('ddpm')
@@ -146,6 +148,12 @@ class StableDiffusion:
            from toolkit.civitai import get_model_path_from_url
            model_path = get_model_path_from_url(self.model_config.name_or_path)

        load_args = {
            'scheduler': self.noise_scheduler,
        }
        if self.model_config.vae_path is not None:
            load_args['vae'] = load_vae(self.model_config.vae_path, dtype)

        if self.model_config.is_xl:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
@@ -159,16 +167,17 @@ class StableDiffusion:
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
                    scheduler_type='ddpm',
                    device=self.device_torch,
                ).to(self.device_torch)
                    variant="fp16",
                    **load_args
                )
            else:
                pipe = pipln.from_single_file(
                    model_path,
                    dtype=dtype,
                    scheduler_type='ddpm',
                    device=self.device_torch,
                ).to(self.device_torch)
                    torch_dtype=self.torch_dtype,
                )
            flush()

            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -204,23 +213,25 @@ class StableDiffusion:
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
                    scheduler_type='dpm',
                    device=self.device_torch,
                    load_safety_checker=False,
                    requires_safety_checker=False,
                    safety_checker=False,
                    variant="fp16"
                    variant="fp16",
                    **load_args
                ).to(self.device_torch)
            else:
                pipe = pipln.from_single_file(
                    model_path,
                    dtype=dtype,
                    scheduler_type='dpm',
                    device=self.device_torch,
                    load_safety_checker=False,
                    requires_safety_checker=False,
                    safety_checker=False
                    torch_dtype=self.torch_dtype,
                    safety_checker=False,
                    **load_args
                ).to(self.device_torch)
            flush()

            pipe.register_to_config(requires_safety_checker=False)
            text_encoder = pipe.text_encoder
@@ -235,10 +246,6 @@ class StableDiffusion:
        # add hacks to unet to help training
        # pipe.unet = prepare_unet_for_training(pipe.unet)

        if self.model_config.vae_path is not None:
            external_vae = load_vae(self.model_config.vae_path, dtype)
            pipe.vae = external_vae

        self.unet = pipe.unet
        self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
        self.vae.eval()
@@ -252,6 +259,7 @@ class StableDiffusion:
        self.pipeline = pipe
        self.is_loaded = True

    @torch.no_grad()
    def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
        # sample_folder = os.path.join(self.save_root, 'samples')
        if self.network is not None:
@@ -266,27 +274,26 @@ class StableDiffusion:
                network.apply_stored_normalizer()
                network.is_normalizing = False

        self.save_device_state()

        # save current seed state for training
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

        original_device_dict = {
            'vae': self.vae.device,
            'unet': self.unet.device,
            # 'tokenizer': self.tokenizer.device,
        }

        # handle sdxl text encoder
        if isinstance(self.text_encoder, list):
            for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
                original_device_dict[f'text_encoder_{i}'] = encoder.device
                encoder.to(self.device_torch)
                encoder.eval()
        else:
            original_device_dict['text_encoder'] = self.text_encoder.device
            self.text_encoder.to(self.device_torch)
            self.text_encoder.eval()

        self.vae.to(self.device_torch)
        self.vae.eval()
        self.unet.to(self.device_torch)
        self.unet.eval()
        flush()

        noise_scheduler = self.noise_scheduler
        if sampler is not None:
@@ -302,7 +309,6 @@ class StableDiffusion:
            else:
                Pipe = StableDiffusionXLPipeline


        # TODO add clip skip
        if self.is_xl:
            pipeline = Pipe(
@@ -328,6 +334,7 @@ class StableDiffusion:
                feature_extractor=None,
                requires_safety_checker=False,
            ).to(self.device_torch)
        flush()
        # disable progress bar
        pipeline.set_progress_bar_config(disable=True)

@@ -366,7 +373,6 @@ class StableDiffusion:
            if sampler.startswith("sample_"):
                extra['use_karras_sigmas'] = True


            img = pipeline(
                prompt=gen_config.prompt,
                prompt_2=gen_config.prompt_2,
@@ -400,13 +406,7 @@ class StableDiffusion:
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)

        self.vae.to(original_device_dict['vae'])
        self.unet.to(original_device_dict['unet'])
        if isinstance(self.text_encoder, list):
            for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
                encoder.to(original_device_dict[f'text_encoder_{i}'])
        else:
            self.text_encoder.to(original_device_dict['text_encoder'])
        self.restore_device_state()
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier
@@ -666,7 +666,6 @@ class StableDiffusion:
            image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)

        images = torch.stack(image_list)
        flush()
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config['scaling_factor']
        latents = latents.to(device, dtype=dtype)
@@ -766,7 +765,8 @@ class StableDiffusion:
            state_dict[new_key] = v
        return state_dict

    def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[str, Parameter]:
    def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
        str, Parameter]:
        named_params: OrderedDict[str, Parameter] = OrderedDict()
        if vae:
            for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -794,7 +794,6 @@ class StableDiffusion:

        return named_params


    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
        version_string = '1'
        if self.is_v2:
@@ -865,3 +864,103 @@ class StableDiffusion:
            print(f"Found {len(params)} trainable parameter in text encoder")

        return trainable_parameters

    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
        self.device_state = {
            'vae': {
                'training': self.vae.training,
                'device': self.vae.device,
            },
            'unet': {
                'training': self.unet.training,
                'device': self.unet.device,
            },
        }
        if isinstance(self.text_encoder, list):
            self.device_state['text_encoder']: List[dict] = []
            for encoder in self.text_encoder:
                self.device_state['text_encoder'].append({
                    'training': encoder.training,
                    'device': encoder.device,
                })
        else:
            self.device_state['text_encoder'] = {
                'training': self.text_encoder.training,
                'device': self.text_encoder.device,
            }

    def restore_device_state(self):
        # restores the device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.device_state is None:
            return
        self.set_device_state(self.device_state)
        self.device_state = None

    def set_device_state(self, state):
        if state['vae']['training']:
            self.vae.train()
        else:
            self.vae.eval()
        self.vae.to(state['vae']['device'])
        if state['unet']['training']:
            self.unet.train()
        else:
            self.unet.eval()
        self.unet.to(state['unet']['device'])
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
                if state['text_encoder'][i]['training']:
                    encoder.train()
                else:
                    encoder.eval()
                encoder.to(state['text_encoder'][i]['device'])
        else:
            if state['text_encoder']['training']:
                self.text_encoder.train()
            else:
                self.text_encoder.eval()
            self.text_encoder.to(state['text_encoder']['device'])
        flush()

    def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
        # sets a preset for device state

        # save current state first
        self.save_device_state()

        active_modules = []
        training_modules = []
        if device_state_preset in ['cache_latents']:
            active_modules = ['vae']

        state = {}
        # vae
        state['vae'] = {
            'training': 'vae' in training_modules,
            'device': self.device_torch if 'vae' in active_modules else 'cpu',
        }

        # unet
        state['unet'] = {
            'training': 'unet' in training_modules,
            'device': self.device_torch if 'unet' in active_modules else 'cpu',
        }

        # text encoder
        if isinstance(self.text_encoder, list):
            state['text_encoder'] = []
            for i, encoder in enumerate(self.text_encoder):
                state['text_encoder'].append({
                    'training': 'text_encoder' in training_modules,
                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
                })
        else:
            state['text_encoder'] = {
                'training': 'text_encoder' in training_modules,
                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
            }

        self.set_device_state(state)

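A short sketch (not from the diff) of how these helpers compose for the latent-caching pass: the preset snapshots the current devices, parks everything except the VAE on the CPU, and restore_device_state() puts the models back where they were. It assumes `sd` is a loaded StableDiffusion instance and `images` is a [B, 3, H, W] tensor already normalized to [-1, 1]:

sd.set_device_state_preset('cache_latents')   # only the VAE stays on sd.device_torch
latents = sd.encode_images(images.to(sd.device_torch, dtype=sd.torch_dtype))
sd.restore_device_state()                     # unet / text encoder(s) return to their prior devices
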