mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Massive speed increase. Added latent caching both to disk and to memory
This commit is contained in:
@@ -2,7 +2,7 @@ import gc
|
||||
import json
|
||||
import shutil
|
||||
import typing
|
||||
from typing import Union, List, Tuple, Iterator
|
||||
from typing import Union, List, Literal, Iterator
|
||||
import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
@@ -48,6 +48,8 @@ DO_NOT_TRAIN_WEIGHTS = [
|
||||
"unet_time_embedding.linear_2.weight",
|
||||
]
|
||||
|
||||
DeviceStatePreset = Literal['cache_latents']
|
||||
|
||||
|
||||
class BlankNetwork:
|
||||
|
||||
@@ -102,6 +104,8 @@ class StableDiffusion:
|
||||
self.model_config = model_config
|
||||
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
|
||||
|
||||
self.device_state = None
|
||||
|
||||
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
|
||||
self.vae: Union[None, 'AutoencoderKL']
|
||||
self.unet: Union[None, 'UNet2DConditionModel']
|
||||
@@ -128,8 +132,6 @@ class StableDiffusion:
|
||||
if self.is_loaded:
|
||||
return
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
|
||||
# TODO handle other schedulers
|
||||
# sch = KDPM2DiscreteScheduler
|
||||
if self.noise_scheduler is None:
|
||||
scheduler = get_sampler('ddpm')
|
||||
@@ -146,6 +148,12 @@ class StableDiffusion:
|
||||
from toolkit.civitai import get_model_path_from_url
|
||||
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||
|
||||
load_args = {
|
||||
'scheduler': self.noise_scheduler,
|
||||
}
|
||||
if self.model_config.vae_path is not None:
|
||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -159,16 +167,17 @@ class StableDiffusion:
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
variant="fp16",
|
||||
**load_args
|
||||
)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='ddpm',
|
||||
device=self.device_torch,
|
||||
).to(self.device_torch)
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
flush()
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
@@ -204,23 +213,25 @@ class StableDiffusion:
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
load_safety_checker=False,
|
||||
requires_safety_checker=False,
|
||||
safety_checker=False,
|
||||
variant="fp16"
|
||||
variant="fp16",
|
||||
**load_args
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
scheduler_type='dpm',
|
||||
device=self.device_torch,
|
||||
load_safety_checker=False,
|
||||
requires_safety_checker=False,
|
||||
safety_checker=False
|
||||
torch_dtype=self.torch_dtype,
|
||||
safety_checker=False,
|
||||
**load_args
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
|
||||
pipe.register_to_config(requires_safety_checker=False)
|
||||
text_encoder = pipe.text_encoder
|
||||
@@ -235,10 +246,6 @@ class StableDiffusion:
|
||||
# add hacks to unet to help training
|
||||
# pipe.unet = prepare_unet_for_training(pipe.unet)
|
||||
|
||||
if self.model_config.vae_path is not None:
|
||||
external_vae = load_vae(self.model_config.vae_path, dtype)
|
||||
pipe.vae = external_vae
|
||||
|
||||
self.unet = pipe.unet
|
||||
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||
self.vae.eval()
|
||||
@@ -252,6 +259,7 @@ class StableDiffusion:
|
||||
self.pipeline = pipe
|
||||
self.is_loaded = True
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
|
||||
# sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if self.network is not None:
|
||||
@@ -266,27 +274,26 @@ class StableDiffusion:
|
||||
network.apply_stored_normalizer()
|
||||
network.is_normalizing = False
|
||||
|
||||
self.save_device_state()
|
||||
|
||||
# save current seed state for training
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
|
||||
original_device_dict = {
|
||||
'vae': self.vae.device,
|
||||
'unet': self.unet.device,
|
||||
# 'tokenizer': self.tokenizer.device,
|
||||
}
|
||||
|
||||
# handle sdxl text encoder
|
||||
if isinstance(self.text_encoder, list):
|
||||
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
|
||||
original_device_dict[f'text_encoder_{i}'] = encoder.device
|
||||
encoder.to(self.device_torch)
|
||||
encoder.eval()
|
||||
else:
|
||||
original_device_dict['text_encoder'] = self.text_encoder.device
|
||||
self.text_encoder.to(self.device_torch)
|
||||
self.text_encoder.eval()
|
||||
|
||||
self.vae.to(self.device_torch)
|
||||
self.vae.eval()
|
||||
self.unet.to(self.device_torch)
|
||||
self.unet.eval()
|
||||
flush()
|
||||
|
||||
noise_scheduler = self.noise_scheduler
|
||||
if sampler is not None:
|
||||
@@ -302,7 +309,6 @@ class StableDiffusion:
|
||||
else:
|
||||
Pipe = StableDiffusionXLPipeline
|
||||
|
||||
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
pipeline = Pipe(
|
||||
@@ -328,6 +334,7 @@ class StableDiffusion:
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
# disable progress bar
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -366,7 +373,6 @@ class StableDiffusion:
|
||||
if sampler.startswith("sample_"):
|
||||
extra['use_karras_sigmas'] = True
|
||||
|
||||
|
||||
img = pipeline(
|
||||
prompt=gen_config.prompt,
|
||||
prompt_2=gen_config.prompt_2,
|
||||
@@ -400,13 +406,7 @@ class StableDiffusion:
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
self.vae.to(original_device_dict['vae'])
|
||||
self.unet.to(original_device_dict['unet'])
|
||||
if isinstance(self.text_encoder, list):
|
||||
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
|
||||
encoder.to(original_device_dict[f'text_encoder_{i}'])
|
||||
else:
|
||||
self.text_encoder.to(original_device_dict['text_encoder'])
|
||||
self.restore_device_state()
|
||||
if self.network is not None:
|
||||
self.network.train()
|
||||
self.network.multiplier = start_multiplier
|
||||
@@ -666,7 +666,6 @@ class StableDiffusion:
|
||||
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
|
||||
|
||||
images = torch.stack(image_list)
|
||||
flush()
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
latents = latents * self.vae.config['scaling_factor']
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
@@ -766,7 +765,8 @@ class StableDiffusion:
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[str, Parameter]:
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
|
||||
str, Parameter]:
|
||||
named_params: OrderedDict[str, Parameter] = OrderedDict()
|
||||
if vae:
|
||||
for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
|
||||
@@ -794,7 +794,6 @@ class StableDiffusion:
|
||||
|
||||
return named_params
|
||||
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
version_string = '1'
|
||||
if self.is_v2:
|
||||
@@ -865,3 +864,103 @@ class StableDiffusion:
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
return trainable_parameters
|
||||
|
||||
def save_device_state(self):
|
||||
# saves the current device state for all modules
|
||||
# this is useful for when we want to alter the state and restore it
|
||||
self.device_state = {
|
||||
'vae': {
|
||||
'training': self.vae.training,
|
||||
'device': self.vae.device,
|
||||
},
|
||||
'unet': {
|
||||
'training': self.unet.training,
|
||||
'device': self.unet.device,
|
||||
},
|
||||
}
|
||||
if isinstance(self.text_encoder, list):
|
||||
self.device_state['text_encoder']: List[dict] = []
|
||||
for encoder in self.text_encoder:
|
||||
self.device_state['text_encoder'].append({
|
||||
'training': encoder.training,
|
||||
'device': encoder.device,
|
||||
})
|
||||
else:
|
||||
self.device_state['text_encoder'] = {
|
||||
'training': self.text_encoder.training,
|
||||
'device': self.text_encoder.device,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
# restores the device state for all modules
|
||||
# this is useful for when we want to alter the state and restore it
|
||||
if self.device_state is None:
|
||||
return
|
||||
self.set_device_state(self.device_state)
|
||||
self.device_state = None
|
||||
|
||||
def set_device_state(self, state):
|
||||
if state['vae']['training']:
|
||||
self.vae.train()
|
||||
else:
|
||||
self.vae.eval()
|
||||
self.vae.to(state['vae']['device'])
|
||||
if state['unet']['training']:
|
||||
self.unet.train()
|
||||
else:
|
||||
self.unet.eval()
|
||||
self.unet.to(state['unet']['device'])
|
||||
if isinstance(self.text_encoder, list):
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
if state['text_encoder'][i]['training']:
|
||||
encoder.train()
|
||||
else:
|
||||
encoder.eval()
|
||||
encoder.to(state['text_encoder'][i]['device'])
|
||||
else:
|
||||
if state['text_encoder']['training']:
|
||||
self.text_encoder.train()
|
||||
else:
|
||||
self.text_encoder.eval()
|
||||
self.text_encoder.to(state['text_encoder']['device'])
|
||||
flush()
|
||||
|
||||
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
|
||||
# sets a preset for device state
|
||||
|
||||
# save current state first
|
||||
self.save_device_state()
|
||||
|
||||
active_modules = []
|
||||
training_modules = []
|
||||
if device_state_preset in ['cache_latents']:
|
||||
active_modules = ['vae']
|
||||
|
||||
state = {}
|
||||
# vae
|
||||
state['vae'] = {
|
||||
'training': 'vae' in training_modules,
|
||||
'device': self.device_torch if 'vae' in active_modules else 'cpu',
|
||||
}
|
||||
|
||||
# unet
|
||||
state['unet'] = {
|
||||
'training': 'unet' in training_modules,
|
||||
'device': self.device_torch if 'unet' in active_modules else 'cpu',
|
||||
}
|
||||
|
||||
# text encoder
|
||||
if isinstance(self.text_encoder, list):
|
||||
state['text_encoder'] = []
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
state['text_encoder'].append({
|
||||
'training': 'text_encoder' in training_modules,
|
||||
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
|
||||
})
|
||||
else:
|
||||
state['text_encoder'] = {
|
||||
'training': 'text_encoder' in training_modules,
|
||||
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
|
||||
}
|
||||
|
||||
self.set_device_state(state)
|
||||
|
||||
Reference in New Issue
Block a user