diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 704bdcd8..2baf7f47 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -22,19 +22,21 @@ class SDTrainer(BaseSDTrainProcess):
         pass

     def hook_before_train_loop(self):
-        self.sd.vae.eval()
-        self.sd.vae.to(self.device_torch)
-
-        # textual inversion
-        # if self.embedding is not None:
-        #     set text encoder to train. Not sure if this is necessary but diffusers example did it
-        #     self.sd.text_encoder.train()
+        # move vae to device if we did not cache latents
+        if not self.is_latents_cached:
+            self.sd.vae.eval()
+            self.sd.vae.to(self.device_torch)
+        else:
+            # offload it. Already cached
+            self.sd.vae.to('cpu')

     def hook_train_loop(self, batch):
+        dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         network_weight_list = batch.get_network_weight_list()
-        flush()
+        # flush()
+        self.optimizer.zero_grad()

         # text encoding
         grad_on_text_encoder = False
@@ -57,9 +59,9 @@ class SDTrainer(BaseSDTrainProcess):
         with network:
             with torch.set_grad_enabled(grad_on_text_encoder):
                 conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
-            # if not grad_on_text_encoder:
-            #     # detach the embeddings
-            #     conditional_embeds = conditional_embeds.detach()
+            if not grad_on_text_encoder:
+                # detach the embeddings
+                conditional_embeds = conditional_embeds.detach()
             # flush()

             noise_pred = self.sd.predict_noise(
@@ -68,7 +70,7 @@ class SDTrainer(BaseSDTrainProcess):
                 timestep=timesteps,
                 guidance_scale=1.0,
             )
-            flush()
+            # flush()

             # 9.18 gb
             noise = noise.to(self.device_torch, dtype=dtype).detach()
@@ -95,11 +97,10 @@ class SDTrainer(BaseSDTrainProcess):
             # I spent weeks on fighting this. DON'T DO IT
             loss.backward()
             torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
-            flush()
+            # flush()

         # apply gradients
         self.optimizer.step()
-        self.optimizer.zero_grad()
         self.lr_scheduler.step()

         if self.embedding is not None:
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 45e77dcc..a2943c10 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -16,6 +16,7 @@ from toolkit.lycoris_special import LycorisSpecialNetwork
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
+from toolkit.progress_bar import ToolkitProgressBar
 from toolkit.sampler import get_sampler
 from toolkit.scheduler import get_lr_scheduler

@@ -73,6 +74,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.data_loader_reg: Union[DataLoader, None] = None
         self.trigger_word = self.get_conf('trigger_word', None)

+        # store if all are cached. Allows us to not load the vae if we don't need to
+        self.is_latents_cached = True
         raw_datasets = self.get_conf('datasets', None)
         if raw_datasets is not None and len(raw_datasets) > 0:
             raw_datasets = preprocess_dataset_raw_config(raw_datasets)
@@ -82,6 +85,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if raw_datasets is not None and len(raw_datasets) > 0:
             for raw_dataset in raw_datasets:
                 dataset = DatasetConfig(**raw_dataset)
+                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+                if not is_caching:
+                    self.is_latents_cached = False
                 if dataset.is_reg:
                     if self.datasets_reg is None:
                         self.datasets_reg = []
@@ -355,9 +361,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             print("load_weights not implemented for non-network models")
             return None

-    def process_general_training_batch(self, batch):
+    def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
-            imgs = batch.tensor
             prompts = batch.get_caption_list()
             is_reg_list = batch.get_is_reg_list()

@@ -382,11 +387,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     )
                 conditioned_prompts.append(prompt)

-            batch_size = imgs.shape[0]
-
             dtype = get_torch_dtype(self.train_config.dtype)
-            imgs = imgs.to(self.device_torch, dtype=dtype)
-            latents = self.sd.encode_images(imgs)
+            imgs = None
+            if batch.tensor is not None:
+                imgs = batch.tensor
+                imgs = imgs.to(self.device_torch, dtype=dtype)
+            if batch.latents is not None:
+                latents = batch.latents.to(self.device_torch, dtype=dtype)
+            else:
+                latents = self.sd.encode_images(imgs)
+                flush()
+
+            batch_size = latents.shape[0]

             self.sd.noise_scheduler.set_timesteps(
                 self.train_config.max_denoising_steps, device=self.device_torch
@@ -397,8 +409,8 @@

             # get noise
             noise = self.sd.get_latent_noise(
-                pixel_height=imgs.shape[2],
-                pixel_width=imgs.shape[3],
+                height=latents.shape[2],
+                width=latents.shape[3],
                 batch_size=batch_size,
                 noise_offset=self.train_config.noise_offset
             ).to(self.device_torch, dtype=dtype)
@@ -416,23 +428,12 @@
     def run(self):
         # run base process run
         BaseTrainProcess.run(self)

-        ### HOOk ###
-        self.before_dataset_load()
-        # load datasets if passed in the root process
-        if self.datasets is not None:
-            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
-        if self.datasets_reg is not None:
-            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size)
-
         ### HOOK ###
         self.hook_before_model_load()
         # run base sd process run
         self.sd.load_model()

-        if self.train_config.gradient_checkpointing:
-            # may get disabled elsewhere
-            self.sd.unet.enable_gradient_checkpointing()
-
         dtype = get_torch_dtype(self.train_config.dtype)

         # model is loaded from BaseSDProcess
@@ -480,6 +481,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
             vae.eval()
             flush()

+        ### HOOK ###
+        self.before_dataset_load()
+        # load datasets if passed in the root process
+        if self.datasets is not None:
+            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
+        if self.datasets_reg is not None:
+            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
+
         if self.network_config is not None:
             # TODO should we completely switch to LycorisSpecialNetwork?
@@ -667,13 +676,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.print("Generating baseline samples before training")
             self.sample(0)

-        self.progress_bar = tqdm(
+        self.progress_bar = ToolkitProgressBar(
             total=self.train_config.steps,
             desc=self.job.name,
             leave=True,
             initial=self.step_num,
             iterable=range(0, self.train_config.steps),
         )
+        self.progress_bar.pause()

         if self.data_loader is not None:
             dataloader = self.data_loader
@@ -691,12 +701,30 @@
         # zero any gradients
         optimizer.zero_grad()

-        flush()
+        self.lr_scheduler.step(self.step_num)
+
+        if self.embedding is not None or self.train_config.train_text_encoder:
+            if isinstance(self.sd.text_encoder, list):
+                for te in self.sd.text_encoder:
+                    te.train()
+            else:
+                self.sd.text_encoder.train()
+        else:
+            if isinstance(self.sd.text_encoder, list):
+                for te in self.sd.text_encoder:
+                    te.eval()
+            else:
+                self.sd.text_encoder.eval()
+        if self.train_config.train_unet or self.embedding:
+            self.sd.unet.train()
+        else:
+            self.sd.unet.eval()
+        flush()

         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
+            self.progress_bar.unpause()
             with torch.no_grad():
                 # if is even step and we have a reg dataset, use that
                 # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
@@ -725,21 +753,14 @@
             # turn on normalization if we are using it and it is not on
             if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
                 self.network.is_normalizing = True
-            flush()
-            if self.embedding is not None or self.train_config.train_text_encoder:
-                if isinstance(self.sd.text_encoder, list):
-                    for te in self.sd.text_encoder:
-                        te.train()
-                else:
-                    self.sd.text_encoder.train()
-
-            self.sd.unet.train()
+            # flush()

             ### HOOK ###
             loss_dict = self.hook_train_loop(batch)
-            flush()
+            # flush()
             # setup the networks to gradient checkpointing and everything works
             with torch.no_grad():
+                torch.cuda.empty_cache()
                 if self.train_config.optimizer.lower().startswith('dadaptation') or \
                         self.train_config.optimizer.lower().startswith('prodigy'):
                     learning_rate = (
@@ -757,24 +778,27 @@
             # don't do on first step
             if self.step_num != self.start_step:
-                # pause progress bar
-                self.progress_bar.unpause()  # makes it so doesn't track time
                 if is_sample_step:
+                    self.progress_bar.pause()
                     # print above the progress bar
                     self.sample(self.step_num)
+                    self.progress_bar.unpause()

                 if is_save_step:
                     # print above the progress bar
+                    self.progress_bar.pause()
                     self.print(f"Saving at step {self.step_num}")
                     self.save(self.step_num)
+                    self.progress_bar.unpause()

                 if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
+                    self.progress_bar.pause()
                     # log to tensorboard
                     if self.writer is not None:
                         for key, value in loss_dict.items():
                             self.writer.add_scalar(f"{key}", value, self.step_num)
                         self.writer.add_scalar(f"lr", learning_rate, self.step_num)
-                    self.progress_bar.refresh()
+                    self.progress_bar.unpause()

             # sets progress bar to match out step
             self.progress_bar.update(step - self.progress_bar.n)
@@ -789,6 +813,7 @@
             if isinstance(batch, DataLoaderBatchDTO):
                 batch.cleanup()

+        self.progress_bar.close()
         self.sample(self.step_num + 1)
         print("")
         self.save()
diff --git a/toolkit/basic.py b/toolkit/basic.py
index 248ffc4e..f0464d69 100644
--- a/toolkit/basic.py
+++ b/toolkit/basic.py
@@ -1,4 +1,13 @@
+import gc
+
+import torch


 def value_map(inputs, min_in, max_in, min_out, max_out):
     return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
+
+
+def flush(garbage_collect=True):
+    torch.cuda.empty_cache()
+    if garbage_collect:
+        gc.collect()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 055a4f7d..50faac60 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -197,6 +197,11 @@ class DatasetConfig:
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))

+        # cache latents will store them in memory
+        self.cache_latents: bool = kwargs.get('cache_latents', False)
+        # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
+        self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
+
         # legacy compatability
         legacy_caption_type = kwargs.get('caption_type', None)
         if legacy_caption_type:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 0cd37840..682caceb 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -1,11 +1,10 @@
 import json
 import os
 import random
-from typing import List
+from typing import List, TYPE_CHECKING

 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torchvision import transforms
@@ -13,11 +12,13 @@
 from torch.utils.data import Dataset, DataLoader, ConcatDataset
 from tqdm import tqdm

 import albumentations as A
-from toolkit import image_utils
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO

+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+

 class ImageDataset(Dataset, CaptionMixin):
     def __init__(self, config):
@@ -288,9 +289,14 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)


-class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
+class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):

-    def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
+    def __init__(
+            self,
+            dataset_config: 'DatasetConfig',
+            batch_size=1,
+            sd: 'StableDiffusion' = None,
+    ):
         super().__init__()
         self.dataset_config = dataset_config
         folder_path = dataset_config.folder_path
@@ -298,6 +304,15 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
         if self.dataset_path is None:
             self.dataset_path = folder_path

+        self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
+        self.is_caching_latents_to_memory = dataset_config.cache_latents
+        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+
+        self.sd = sd
+
+        if self.sd is None and self.is_caching_latents:
+            raise ValueError(f"sd is required for caching latents")
+
         self.caption_type = dataset_config.caption_ext
         self.default_caption = dataset_config.default_caption
         self.random_scale = dataset_config.random_scale
@@ -344,19 +359,21 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
         # print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

-        self.setup_epoch()
-
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
         ])

+        self.setup_epoch()
+
     def setup_epoch(self):
         # TODO: set this up to redo cropping and everything else
         # do not call for now
         if self.dataset_config.buckets:
             # setup buckets
             self.setup_buckets()
+        if self.is_caching_latents:
+            self.cache_latents_all_latents()

     def __len__(self):
         if self.dataset_config.buckets:
@@ -381,7 +398,11 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
             return self._get_single_item(item)


-def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
+def get_dataloader_from_datasets(
+        dataset_options,
+        batch_size=1,
+        sd: 'StableDiffusion' = None,
+) -> DataLoader:
     if dataset_options is None or len(dataset_options) == 0:
         return None
@@ -402,7 +423,7 @@
     for config in dataset_config_list:
         if config.type == 'image':
-            dataset = AiToolkitDataset(config, batch_size=batch_size)
+            dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
             datasets.append(dataset)
             if config.buckets:
                 has_buckets = True
@@ -432,14 +453,14 @@
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=2
+            num_workers=1
         )
     else:
         data_loader = DataLoader(
             concatenated_dataset,
             batch_size=batch_size,
             shuffle=True,
-            num_workers=2,
+            num_workers=1,
             collate_fn=dto_collation
         )
     return data_loader
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 32a91f48..e6986c34 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -6,7 +6,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from toolkit import image_utils
-from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin
+from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin

 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
@@ -21,8 +21,9 @@ def print_once(msg):
         printed_messages.append(msg)


-class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
+class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
     def __init__(self, **kwargs):
+        super().__init__()
         self.path = kwargs.get('path', None)
         self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         # process width and height
@@ -53,12 +54,22 @@
     def cleanup(self):
         self.tensor = None
+        self.cleanup_latent()


 class DataLoaderBatchDTO:
     def __init__(self, **kwargs):
         self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
-        self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+        is_latents_cached = self.file_items[0].is_latent_cached
+        self.tensor: Union[torch.Tensor, None] = None
+        self.latents: Union[torch.Tensor, None] = None
+        if not is_latents_cached:
+            # only return a tensor if latents are not cached
+            self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+        # if we have encoded latents, we concatenate them
+        self.latents: Union[torch.Tensor, None] = None
+        if is_latents_cached:
+            self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])

     def get_is_reg_list(self):
         return [x.is_reg for x in self.file_items]
@@ -82,3 +93,4 @@ class DataLoaderBatchDTO:
         self.tensor = None
         for file_item in self.file_items:
             file_item.cleanup()
+        del self.tensor
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 7e103610..2c2ffd59 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -1,14 +1,26 @@
+import base64
+import hashlib
+import json
 import math
 import os
 import random
+from collections import OrderedDict
 from typing import TYPE_CHECKING, List, Dict, Union

+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+
+from toolkit.basic import flush
 from toolkit.buckets import get_bucket_for_image_size
+from toolkit.metadata import get_meta_for_safetensors
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
 from PIL import Image
 from PIL.ImageOps import exif_transpose

+from toolkit.train_tools import get_torch_dtype
+
 if TYPE_CHECKING:
     from toolkit.data_loader import AiToolkitDataset
     from toolkit.data_transfer_object.data_loader import FileItemDTO
@@ -219,7 +231,9 @@ class ImageProcessingDTOMixin:
             self: 'FileItemDTO',
             transform: Union[None, transforms.Compose]
     ):
-        # todo make sure this matches
+        # if we are caching latents, just do that and skip loading the image
+        if self.is_latent_cached:
+            self.get_latent()
+            return
         try:
             img = Image.open(self.path).convert('RGB')
             img = exif_transpose(img)
@@ -265,3 +279,139 @@ class ImageProcessingDTOMixin:
             img = transform(img)

         self.tensor = img
+
+
+class LatentCachingFileItemDTOMixin:
+    def __init__(self):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__()
+        self._encoded_latent: Union[torch.Tensor, None] = None
+        self._latent_path: Union[str, None] = None
+        self.is_latent_cached = False
+        self.is_caching_to_disk = False
+        self.is_caching_to_memory = False
+        self.latent_load_device = 'cpu'
+        # sd1 or sdxl or others
+        self.latent_space_version = 'sd1'
+        # todo, increment this if we change the latent format to invalidate cache
+        self.latent_version = 1
+
+    def get_latent_info_dict(self: 'FileItemDTO'):
+        return OrderedDict([
+            ("filename", os.path.basename(self.path)),
+            ("scale_to_width", self.scale_to_width),
+            ("scale_to_height", self.scale_to_height),
+            ("crop_x", self.crop_x),
+            ("crop_y", self.crop_y),
+            ("crop_width", self.crop_width),
+            ("crop_height", self.crop_height),
+            ("latent_space_version", self.latent_space_version),
+            ("latent_version", self.latent_version),
+        ])
+
+    def get_latent_path(self: 'FileItemDTO', recalculate=False):
+        if self._latent_path is not None and not recalculate:
+            return self._latent_path
+        else:
+            # we store latents in a folder in same path as image called _latent_cache
+            img_dir = os.path.dirname(self.path)
+            latent_dir = os.path.join(img_dir, '_latent_cache')
+            hash_dict = self.get_latent_info_dict()
+            filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
+            # get base64 hash of md5 checksum of hash_dict
+            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+            hash_str = hash_str.replace('=', '')
+            self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
+
+        return self._latent_path
+
+    def cleanup_latent(self):
+        if self._encoded_latent is not None:
+            if not self.is_caching_to_memory:
+                # we are caching on disk, don't save in memory
+                self._encoded_latent = None
+            else:
+                # move it back to cpu
+                self._encoded_latent = self._encoded_latent.to('cpu')
+
+    def get_latent(self, device=None):
+        if not self.is_latent_cached:
+            return None
+        if self._encoded_latent is None:
+            # load it from disk
+            state_dict = load_file(
+                self.get_latent_path(),
+                device=device if device is not None else self.latent_load_device
+            )
+            self._encoded_latent = state_dict['latent']
+        return self._encoded_latent
+
+
+class LatentCachingMixin:
+    def __init__(self: 'AiToolkitDataset', **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(**kwargs)
+        self.latent_cache = {}
+
+    def cache_latents_all_latents(self: 'AiToolkitDataset'):
+        print(f"Caching latents for {self.dataset_path}")
+        # cache all latents to disk
+        to_disk = self.is_caching_latents_to_disk
+        to_memory = self.is_caching_latents_to_memory
+        if to_disk:
+            print(" - Saving latents to disk")
+        if to_memory:
+            print(" - Keeping latents in memory")
+
+        # move sd items to cpu except for vae
+        self.sd.set_device_state_preset('cache_latents')
+
+        # use tqdm to show progress
+        for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
+            # set latent space version
+            if self.sd.is_xl:
+                file_item.latent_space_version = 'sdxl'
+            else:
+                file_item.latent_space_version = 'sd1'
+            file_item.is_caching_to_disk = to_disk
+            file_item.is_caching_to_memory = to_memory
+            file_item.latent_load_device = self.sd.device
+
+            latent_path = file_item.get_latent_path(recalculate=True)
+            # check if it is saved to disk already
+            if os.path.exists(latent_path):
+                if to_memory:
+                    # load it into memory
+                    state_dict = load_file(latent_path, device='cpu')
+                    file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+            else:
+                # not saved to disk, calculate
+                # load the image first
+                file_item.load_and_process_image(self.transform)
+                dtype = self.sd.torch_dtype
+                device = self.sd.device_torch
+                # add batch dimension
+                imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+                latent = self.sd.encode_images(imgs).squeeze(0)
+                # save_latent
+                if to_disk:
+                    state_dict = OrderedDict([
+                        ('latent', latent.clone().detach().cpu()),
+                    ])
+                    # metadata
+                    meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
+                    os.makedirs(os.path.dirname(latent_path), exist_ok=True)
+                    save_file(state_dict, latent_path, metadata=meta)
+
+                if to_memory:
+                    # keep it in memory
+                    file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+
+                flush(garbage_collect=False)
+            file_item.is_latent_cached = True
+
+        # restore device state
+        self.sd.restore_device_state()
diff --git a/toolkit/progress_bar.py b/toolkit/progress_bar.py
new file mode 100644
index 00000000..2707c56e
--- /dev/null
+++ b/toolkit/progress_bar.py
@@ -0,0 +1,22 @@
+from tqdm import tqdm
+import time
+
+
+class ToolkitProgressBar(tqdm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.paused = False
+
+    def pause(self):
+        if not self.paused:
+            self.paused = True
+            self.last_time = self._time()
+
+    def unpause(self):
+        if self.paused:
+            self.paused = False
+            self.start_t += self._time() - self.last_time
+
+    def update(self, *args, **kwargs):
+        if not self.paused:
+            super().update(*args, **kwargs)
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index d0673f5b..9741600f 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -495,7 +495,8 @@ def build_latent_image_batch_for_prompt_pair(

 def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
     if trigger is None:
-        return prompt
+        # process as empty string to remove any [trigger] tokens
+        trigger = ''
     output_prompt = prompt

     default_replacements = ["[name]", "[trigger]"]
@@ -513,15 +514,16 @@ def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_i
         # replace it
         output_prompt = output_prompt.replace(to_replace, replace_with)

-    # see how many times replace_with is in the prompt
-    num_instances = output_prompt.count(replace_with)
+    if trigger.strip() != "":
+        # see how many times replace_with is in the prompt
+        num_instances = output_prompt.count(replace_with)

-    if num_instances == 0 and add_if_not_present:
-        # add it to the beginning of the prompt
-        output_prompt = replace_with + " " + output_prompt
+        if num_instances == 0 and add_if_not_present:
+            # add it to the beginning of the prompt
+            output_prompt = replace_with + " " + output_prompt

-    if num_instances > 1:
-        print(
-            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+        if num_instances > 1:
+            print(
+                f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

     return output_prompt
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 6f87e15a..3ba7638e 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2,7 +2,7 @@ import gc
 import json
 import shutil
 import typing
-from typing import Union, List, Tuple, Iterator
+from typing import Union, List, Literal, Iterator
 import sys
 import os
 from collections import OrderedDict
@@ -48,6 +48,8 @@ DO_NOT_TRAIN_WEIGHTS = [
     "unet_time_embedding.linear_2.weight",
 ]

+DeviceStatePreset = Literal['cache_latents']
+

 class BlankNetwork:

@@ -102,6 +104,8 @@ class StableDiffusion:
         self.model_config = model_config
         self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

+        self.device_state = None
+
         self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
         self.vae: Union[None, 'AutoencoderKL']
         self.unet: Union[None, 'UNet2DConditionModel']
@@ -128,8 +132,6 @@ class StableDiffusion:
         if self.is_loaded:
             return
         dtype = get_torch_dtype(self.dtype)
-
-        # TODO handle other schedulers
         # sch = KDPM2DiscreteScheduler
         if self.noise_scheduler is None:
             scheduler = get_sampler('ddpm')
@@ -146,6 +148,12 @@ class StableDiffusion:
                 from toolkit.civitai import get_model_path_from_url
                 model_path = get_model_path_from_url(self.model_config.name_or_path)

+        load_args = {
+            'scheduler': self.noise_scheduler,
+        }
+        if self.model_config.vae_path is not None:
+            load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
+
         if self.model_config.is_xl:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
@@ -159,16 +167,17 @@ class StableDiffusion:
                 pipe = pipln.from_pretrained(
                     model_path,
                     dtype=dtype,
-                    scheduler_type='ddpm',
                     device=self.device_torch,
-                ).to(self.device_torch)
+                    variant="fp16",
+                    **load_args
+                )
             else:
                 pipe = pipln.from_single_file(
                     model_path,
-                    dtype=dtype,
-                    scheduler_type='ddpm',
                     device=self.device_torch,
-                ).to(self.device_torch)
+                    torch_dtype=self.torch_dtype,
+                )
+            flush()

             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
             tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -204,23 +213,25 @@ class StableDiffusion:
                 pipe = pipln.from_pretrained(
                     model_path,
                     dtype=dtype,
-                    scheduler_type='dpm',
                     device=self.device_torch,
                     load_safety_checker=False,
                     requires_safety_checker=False,
                     safety_checker=False,
-                    variant="fp16"
+                    variant="fp16",
+                    **load_args
                 ).to(self.device_torch)
             else:
                 pipe = pipln.from_single_file(
                     model_path,
                     dtype=dtype,
-                    scheduler_type='dpm',
                     device=self.device_torch,
                     load_safety_checker=False,
                     requires_safety_checker=False,
-                    safety_checker=False
+                    torch_dtype=self.torch_dtype,
+                    safety_checker=False,
+                    **load_args
                 ).to(self.device_torch)
+            flush()
             pipe.register_to_config(requires_safety_checker=False)

             text_encoder = pipe.text_encoder
@@ -235,10 +246,6 @@ class StableDiffusion:
         # add hacks to unet to help training
         # pipe.unet = prepare_unet_for_training(pipe.unet)

-        if self.model_config.vae_path is not None:
-            external_vae = load_vae(self.model_config.vae_path, dtype)
-            pipe.vae = external_vae
-
         self.unet = pipe.unet
         self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
         self.vae.eval()
@@ -252,6 +259,7 @@ class StableDiffusion:
         self.pipeline = pipe
         self.is_loaded = True

+    @torch.no_grad()
     def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
         # sample_folder = os.path.join(self.save_root, 'samples')
         if self.network is not None:
@@ -266,27 +274,26 @@ class StableDiffusion:
                 network.apply_stored_normalizer()
                 network.is_normalizing = False

+        self.save_device_state()
+
         # save current seed state for training
         rng_state = torch.get_rng_state()
         cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

-        original_device_dict = {
-            'vae': self.vae.device,
-            'unet': self.unet.device,
-            # 'tokenizer': self.tokenizer.device,
-        }
-
         # handle sdxl text encoder
         if isinstance(self.text_encoder, list):
             for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
-                original_device_dict[f'text_encoder_{i}'] = encoder.device
                 encoder.to(self.device_torch)
+                encoder.eval()
         else:
-            original_device_dict['text_encoder'] = self.text_encoder.device
             self.text_encoder.to(self.device_torch)
+            self.text_encoder.eval()

         self.vae.to(self.device_torch)
+        self.vae.eval()
         self.unet.to(self.device_torch)
+        self.unet.eval()
+        flush()

         noise_scheduler = self.noise_scheduler
         if sampler is not None:
@@ -302,7 +309,6 @@ class StableDiffusion:
             else:
                 Pipe = StableDiffusionXLPipeline

-        # TODO add clip skip

         if self.is_xl:
             pipeline = Pipe(
@@ -328,6 +334,7 @@ class StableDiffusion:
                 feature_extractor=None,
                 requires_safety_checker=False,
             ).to(self.device_torch)
+        flush()

         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
@@ -366,7 +373,6 @@ class StableDiffusion:
                     if sampler.startswith("sample_"):
                         extra['use_karras_sigmas'] = True

-
                 img = pipeline(
                     prompt=gen_config.prompt,
                     prompt_2=gen_config.prompt_2,
@@ -400,13 +406,7 @@ class StableDiffusion:
         if cuda_rng_state is not None:
             torch.cuda.set_rng_state(cuda_rng_state)

-        self.vae.to(original_device_dict['vae'])
-        self.unet.to(original_device_dict['unet'])
-        if isinstance(self.text_encoder, list):
-            for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
-                encoder.to(original_device_dict[f'text_encoder_{i}'])
-        else:
-            self.text_encoder.to(original_device_dict['text_encoder'])
+        self.restore_device_state()
         if self.network is not None:
             self.network.train()
             self.network.multiplier = start_multiplier
@@ -666,7 +666,6 @@ class StableDiffusion:
                 image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)

         images = torch.stack(image_list)
-        flush()
         latents = self.vae.encode(images).latent_dist.sample()
         latents = latents * self.vae.config['scaling_factor']
         latents = latents.to(device, dtype=dtype)
@@ -766,7 +765,8 @@ class StableDiffusion:
                 state_dict[new_key] = v
         return state_dict

-    def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[str, Parameter]:
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
+        str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
             for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -794,7 +794,6 @@ class StableDiffusion:

         return named_params

-
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         version_string = '1'
         if self.is_v2:
@@ -865,3 +864,103 @@ class StableDiffusion:
             print(f"Found {len(params)} trainable parameter in text encoder")

         return trainable_parameters
+
+    def save_device_state(self):
+        # saves the current device state for all modules
+        # this is useful for when we want to alter the state and restore it
+        self.device_state = {
+            'vae': {
+                'training': self.vae.training,
+                'device': self.vae.device,
+            },
+            'unet': {
+                'training': self.unet.training,
+                'device': self.unet.device,
+            },
+        }
+        if isinstance(self.text_encoder, list):
+            self.device_state['text_encoder']: List[dict] = []
+            for encoder in self.text_encoder:
+                self.device_state['text_encoder'].append({
+                    'training': encoder.training,
+                    'device': encoder.device,
+                })
+        else:
+            self.device_state['text_encoder'] = {
+                'training': self.text_encoder.training,
+                'device': self.text_encoder.device,
+            }
+
+    def restore_device_state(self):
+        # restores the device state for all modules
+        # this is useful for when we want to alter the state and restore it
+        if self.device_state is None:
+            return
+        self.set_device_state(self.device_state)
+        self.device_state = None
+
+    def set_device_state(self, state):
+        if state['vae']['training']:
+            self.vae.train()
+        else:
+            self.vae.eval()
+        self.vae.to(state['vae']['device'])
+        if state['unet']['training']:
+            self.unet.train()
+        else:
+            self.unet.eval()
+        self.unet.to(state['unet']['device'])
+        if isinstance(self.text_encoder, list):
+            for i, encoder in enumerate(self.text_encoder):
+                if state['text_encoder'][i]['training']:
+                    encoder.train()
+                else:
+                    encoder.eval()
+                encoder.to(state['text_encoder'][i]['device'])
+        else:
+            if state['text_encoder']['training']:
+                self.text_encoder.train()
+            else:
+                self.text_encoder.eval()
+            self.text_encoder.to(state['text_encoder']['device'])
+        flush()
+
+    def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
+        # sets a preset for device state
+
+        # save current state first
+        self.save_device_state()
+
+        active_modules = []
+        training_modules = []
+        if device_state_preset in ['cache_latents']:
+            active_modules = ['vae']
+
+        state = {}
+        # vae
+        state['vae'] = {
+            'training': 'vae' in training_modules,
+            'device': self.device_torch if 'vae' in active_modules else 'cpu',
+        }
+
+        # unet
+        state['unet'] = {
+            'training': 'unet' in training_modules,
+            'device': self.device_torch if 'unet' in active_modules else 'cpu',
+        }
+
+        # text encoder
+        if isinstance(self.text_encoder, list):
+            state['text_encoder'] = []
+            for i, encoder in enumerate(self.text_encoder):
+                state['text_encoder'].append({
+                    'training': 'text_encoder' in training_modules,
+                    'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+                })
+        else:
+            state['text_encoder'] = {
+                'training': 'text_encoder' in training_modules,
+                'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
+            }
+
+        self.set_device_state(state)
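For context, a minimal usage sketch of the dataset-level latent caching introduced above. This is illustrative only and not part of the patch: cache_latents / cache_latents_to_disk are the kwargs read by DatasetConfig in this diff, folder_path and caption_ext are existing config fields, and the paths, batch size, and the sd variable are placeholders.

from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import get_dataloader_from_datasets

# One dataset that caches encoded latents to disk (<folder>/_latent_cache)
# and also keeps them in memory.
dataset = DatasetConfig(
    folder_path='/path/to/images',
    caption_ext='txt',
    cache_latents=True,
    cache_latents_to_disk=True,
)

# sd is an already-loaded toolkit StableDiffusion instance (e.g. self.sd inside
# BaseSDTrainProcess.run()); it is now required whenever any dataset requests
# latent caching, since the dataloader encodes and caches latents up front.
data_loader = get_dataloader_from_datasets([dataset], batch_size=4, sd=sd)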