Massive speed increase. Added latent caching both to disk and to memory

This commit is contained in:
Jaret Burkett
2023-09-10 08:54:49 -06:00
parent 41a3f63b72
commit 34bfeba229
10 changed files with 455 additions and 109 deletions

View File

@@ -2,7 +2,7 @@ import gc
import json
import shutil
import typing
from typing import Union, List, Tuple, Iterator
from typing import Union, List, Literal, Iterator
import sys
import os
from collections import OrderedDict
@@ -48,6 +48,8 @@ DO_NOT_TRAIN_WEIGHTS = [
"unet_time_embedding.linear_2.weight",
]
DeviceStatePreset = Literal['cache_latents']
class BlankNetwork:
@@ -102,6 +104,8 @@ class StableDiffusion:
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
self.device_state = None
self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
self.vae: Union[None, 'AutoencoderKL']
self.unet: Union[None, 'UNet2DConditionModel']
@@ -128,8 +132,6 @@ class StableDiffusion:
if self.is_loaded:
return
dtype = get_torch_dtype(self.dtype)
# TODO handle other schedulers
# sch = KDPM2DiscreteScheduler
if self.noise_scheduler is None:
scheduler = get_sampler('ddpm')
@@ -146,6 +148,12 @@ class StableDiffusion:
from toolkit.civitai import get_model_path_from_url
model_path = get_model_path_from_url(self.model_config.name_or_path)
load_args = {
'scheduler': self.noise_scheduler,
}
if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
@@ -159,16 +167,17 @@ class StableDiffusion:
pipe = pipln.from_pretrained(
model_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
variant="fp16",
**load_args
)
else:
pipe = pipln.from_single_file(
model_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
torch_dtype=self.torch_dtype,
)
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -204,23 +213,25 @@ class StableDiffusion:
pipe = pipln.from_pretrained(
model_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch,
load_safety_checker=False,
requires_safety_checker=False,
safety_checker=False,
variant="fp16"
variant="fp16",
**load_args
).to(self.device_torch)
else:
pipe = pipln.from_single_file(
model_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch,
load_safety_checker=False,
requires_safety_checker=False,
safety_checker=False
torch_dtype=self.torch_dtype,
safety_checker=False,
**load_args
).to(self.device_torch)
flush()
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
@@ -235,10 +246,6 @@ class StableDiffusion:
# add hacks to unet to help training
# pipe.unet = prepare_unet_for_training(pipe.unet)
if self.model_config.vae_path is not None:
external_vae = load_vae(self.model_config.vae_path, dtype)
pipe.vae = external_vae
self.unet = pipe.unet
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
@@ -252,6 +259,7 @@ class StableDiffusion:
self.pipeline = pipe
self.is_loaded = True
@torch.no_grad()
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
# sample_folder = os.path.join(self.save_root, 'samples')
if self.network is not None:
@@ -266,27 +274,26 @@ class StableDiffusion:
network.apply_stored_normalizer()
network.is_normalizing = False
self.save_device_state()
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
original_device_dict = {
'vae': self.vae.device,
'unet': self.unet.device,
# 'tokenizer': self.tokenizer.device,
}
# handle sdxl text encoder
if isinstance(self.text_encoder, list):
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
original_device_dict[f'text_encoder_{i}'] = encoder.device
encoder.to(self.device_torch)
encoder.eval()
else:
original_device_dict['text_encoder'] = self.text_encoder.device
self.text_encoder.to(self.device_torch)
self.text_encoder.eval()
self.vae.to(self.device_torch)
self.vae.eval()
self.unet.to(self.device_torch)
self.unet.eval()
flush()
noise_scheduler = self.noise_scheduler
if sampler is not None:
@@ -302,7 +309,6 @@ class StableDiffusion:
else:
Pipe = StableDiffusionXLPipeline
# TODO add clip skip
if self.is_xl:
pipeline = Pipe(
@@ -328,6 +334,7 @@ class StableDiffusion:
feature_extractor=None,
requires_safety_checker=False,
).to(self.device_torch)
flush()
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -366,7 +373,6 @@ class StableDiffusion:
if sampler.startswith("sample_"):
extra['use_karras_sigmas'] = True
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
@@ -400,13 +406,7 @@ class StableDiffusion:
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
self.vae.to(original_device_dict['vae'])
self.unet.to(original_device_dict['unet'])
if isinstance(self.text_encoder, list):
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
encoder.to(original_device_dict[f'text_encoder_{i}'])
else:
self.text_encoder.to(original_device_dict['text_encoder'])
self.restore_device_state()
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
@@ -666,7 +666,6 @@ class StableDiffusion:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
images = torch.stack(image_list)
flush()
latents = self.vae.encode(images).latent_dist.sample()
latents = latents * self.vae.config['scaling_factor']
latents = latents.to(device, dtype=dtype)
@@ -766,7 +765,8 @@ class StableDiffusion:
state_dict[new_key] = v
return state_dict
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[str, Parameter]:
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
str, Parameter]:
named_params: OrderedDict[str, Parameter] = OrderedDict()
if vae:
for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -794,7 +794,6 @@ class StableDiffusion:
return named_params
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
version_string = '1'
if self.is_v2:
@@ -865,3 +864,103 @@ class StableDiffusion:
print(f"Found {len(params)} trainable parameter in text encoder")
return trainable_parameters
def save_device_state(self):
    """Snapshot the train/eval flag and device of the vae, unet and text encoder(s).

    The snapshot is stored on ``self.device_state`` so the modules can be
    altered temporarily and later put back the way they were. When
    ``self.text_encoder`` is a list, the ``'text_encoder'`` entry is a list
    with one record per encoder, in the same order; otherwise it is a
    single record.
    """

    def _snapshot(module):
        # one record per module: its current mode flag and where it lives
        return {
            'training': module.training,
            'device': module.device,
        }

    snapshot = {
        'vae': _snapshot(self.vae),
        'unet': _snapshot(self.unet),
    }

    encoders = self.text_encoder
    if isinstance(encoders, list):
        snapshot['text_encoder'] = [_snapshot(encoder) for encoder in encoders]
    else:
        snapshot['text_encoder'] = _snapshot(encoders)

    self.device_state = snapshot
def restore_device_state(self):
    """Re-apply the state held in ``self.device_state``, then clear it.

    Does nothing when no state is currently stored (``self.device_state``
    is ``None``), so calling it without a prior save is harmless.
    """
    saved = self.device_state
    if saved is not None:
        self.set_device_state(saved)
        # the snapshot is single-use: drop it once it has been applied
        self.device_state = None
def set_device_state(self, state):
    """Put the vae, unet and text encoder(s) into the mode and device given by ``state``.

    ``state`` is indexed as ``state[name]['training']`` and
    ``state[name]['device']`` for ``'vae'`` and ``'unet'``. For
    ``'text_encoder'`` it is either one such record, or, when
    ``self.text_encoder`` is a list, a sequence of records indexed by
    encoder position.
    """

    def _apply(module, target):
        # switch train/eval mode first, then move the module
        if target['training']:
            module.train()
        else:
            module.eval()
        module.to(target['device'])

    _apply(self.vae, state['vae'])
    _apply(self.unet, state['unet'])

    encoders = self.text_encoder
    if isinstance(encoders, list):
        # look records up by position so a too-short state list still raises
        for index, encoder in enumerate(encoders):
            _apply(encoder, state['text_encoder'][index])
    else:
        _apply(encoders, state['text_encoder'])

    # flush() is defined elsewhere in this file; presumably it releases
    # cached memory after modules have moved -- confirm against its definition
    flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
    """Save the current device state, then apply a named preset.

    For ``'cache_latents'`` only the vae is placed on ``self.device_torch``;
    every other module is sent to ``'cpu'`` and no module is put in training
    mode. A preset name that is not recognised leaves both lists empty, so
    every module ends up on ``'cpu'`` in eval mode.
    """
    # keep the current state so it can be brought back with restore_device_state()
    self.save_device_state()

    on_device = []
    in_training = []

    if device_state_preset in ('cache_latents',):
        on_device = ['vae']

    def _target(module_name):
        # desired record for one module under the chosen preset
        return {
            'training': module_name in in_training,
            'device': self.device_torch if module_name in on_device else 'cpu',
        }

    preset_state = {
        'vae': _target('vae'),
        'unet': _target('unet'),
    }

    if isinstance(self.text_encoder, list):
        # one identical record per encoder in the list
        preset_state['text_encoder'] = [_target('text_encoder') for _ in self.text_encoder]
    else:
        preset_state['text_encoder'] = _target('text_encoder')

    self.set_device_state(preset_state)