Added base setup for training t2i adapters. Currently untested; saw something else shiny I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, it will use cubic sampling for timesteps to favor the timesteps that are most beneficial for training each: for style, favor later timesteps; for content, favor earlier timesteps.
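For reference, a minimal standalone sketch of the timestep sampling this adds (the function and argument names here are illustrative, not the toolkit's API; the real code also maps the result into the configured min/max denoising range):

import torch

def sample_timesteps(batch_size: int, content_or_style: str = 'balanced',
                     num_train_timesteps: int = 1000, device: str = 'cpu') -> torch.Tensor:
    if content_or_style == 'balanced':
        # standard uniform sampling
        return torch.randint(0, num_train_timesteps, (batch_size,), device=device).long()
    # cubic sampling, as in the diffusers t2i-adapter example
    # (see section 3.4 of https://arxiv.org/abs/2302.08453)
    u = torch.rand((batch_size,), device=device)
    if content_or_style == 'style':
        # concentrates samples at low timestep indices (late in the denoising trajectory, low noise)
        t = u ** 3 * num_train_timesteps
    elif content_or_style == 'content':
        # concentrates samples at high timestep indices (early in the denoising trajectory, high noise)
        t = (1 - u ** 3) * num_train_timesteps
    else:
        raise ValueError(f"Unknown content_or_style {content_or_style}")
    return t.long().clamp(0, num_train_timesteps - 1)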

This commit is contained in:
Jaret Burkett
2023-09-16 08:30:38 -06:00
parent 17e4fe40d7
commit 27f343fc08
8 changed files with 314 additions and 84 deletions

View File

@@ -1,17 +1,27 @@
import os.path
from collections import OrderedDict
from PIL import Image
from torch.utils.data import DataLoader
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
def flush():
torch.cuda.empty_cache()
gc.collect()
adapter_transforms = transforms.Compose([
transforms.PILToTensor(),
])
class SDTrainer(BaseSDTrainProcess):
@@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
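# Adapter conditioning images are matched to training images by basename:
# for a batch item <train_dir>/img_001.jpg this looks for
# <adapter_config.image_dir>/img_001.{jpg,jpeg,png,webp}.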
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.adapter_config.image_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
adapter_tensors = []
# load images with torch transforms
for adapter_image in adapter_images:
img = Image.open(adapter_image)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors)
return adapter_tensors
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
adapter_images = None
sigmas = None
if self.adapter:
# todo move this to data loader
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
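# (Assumption: this appears to mirror EulerDiscreteScheduler.scale_model_input, which divides
# the noisy sample by sqrt(sigma^2 + 1) so the UNet sees inputs at the same scale as at inference.)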
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# flush()
self.optimizer.zero_grad()
@@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess):
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter:
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
# flush()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.adapter:
# todo: diffusers does this for t2i training, is it a better approach?
# Denoise the latents
denoised_latents = noise_pred * (-sigmas) + noisy_latents
weighing = sigmas ** -2.0
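# The x0-space error scales roughly with sigma (noisy = latents + sigma * noise in
# k-diffusion terms), so weighting by sigma^-2 keeps the loss magnitude comparable
# across timesteps; this follows the diffusers t2i-adapter example linked above.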
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = batch.latents # we are computing the loss against the denoised latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# MSE loss
loss = torch.mean(
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
dim=1,
)
else:
target = noise
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
# TODO: I think the sigma method does not need this. Check
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
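# For reference, a minimal sketch of min-SNR-gamma weighting (https://arxiv.org/abs/2303.09556).
# This is an assumed formulation; the toolkit's apply_snr_weight helper may differ in detail.
def compute_snr(noise_scheduler, timesteps):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
    return alphas_cumprod / (1.0 - alphas_cumprod)

def min_snr_weight(snr, gamma):
    # clamp the SNR at gamma, then normalize by the raw SNR; multiply into the per-sample loss
    return torch.clamp(snr, max=gamma) / snr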

View File

@@ -5,6 +5,7 @@ from collections import OrderedDict
import os
from typing import Union
from diffusers import T2IAdapter
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
@@ -21,6 +22,7 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -34,7 +36,7 @@ import gc
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig
def flush():
@@ -105,9 +107,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
# t2i adapter
self.adapter_config = None
adapter_raw = self.get_conf('adapter', None)
if adapter_raw is not None:
self.adapter_config = AdapterConfig(**adapter_raw)
# sdxl adapters end in _xl. Only full_adapter_xl for now
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
self.adapter_config.adapter_type += '_xl'
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None:
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
@@ -135,6 +146,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
self.embedding: Union[Embedding, None] = None
# get the device state preset based on what we are training
@@ -144,6 +156,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=self.adapter_config is not None,
train_embedding=self.embed_config is not None,
)
@@ -305,6 +318,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
# replace extension
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
elif self.adapter is not None:
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
@@ -360,20 +382,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
return None
def load_training_state_from_metadata(self, path):
meta = load_metadata_from_safetensors(path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
def load_weights(self, path):
if self.network is not None:
extra_weights = self.network.load_weights(path)
meta = load_metadata_from_safetensors(path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.load_training_state_from_metadata(path)
return extra_weights
else:
print("load_weights not implemented for non-network models")
return None
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
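# Maps each sampled timestep to its scheduler sigma and broadcasts the result to
# n_dim dimensions so it can be applied elementwise to the latent tensor.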
sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch)
timesteps = timesteps.to(self.device_torch)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
@@ -407,8 +444,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
flush()
batch_size = latents.shape[0]
@@ -416,47 +455,46 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
if self.train_config.use_progressive_denoising:
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.min_denoising_steps,
max_out=self.train_config.max_denoising_steps
))
elif self.train_config.use_linear_denoising:
# starts at max steps and walks down to min steps
min_timestep = int(value_map(
self.step_num,
min_in=0,
max_in=self.train_config.max_denoising_steps,
min_out=self.train_config.max_denoising_steps - 1,
max_out=self.train_config.min_denoising_steps
))
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
min_timestep = self.train_config.min_denoising_steps
# todo: improve this, but it skews the odds toward higher timesteps
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
# 50% chance to use midpoint as the min_time_step
mid_point = (self.train_config.max_denoising_steps + min_timestep) / 2
if torch.rand(1) > 0.5:
min_timestep = mid_point
min_timestep = int(min_timestep)
timesteps = torch.randint(
min_timestep,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# get noise
noise = self.sd.get_latent_noise(
@@ -477,6 +515,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
BaseTrainProcess.run(self)
@@ -653,9 +692,34 @@ class BaseSDTrainProcess(BaseTrainProcess):
# set trainable params
params = self.embedding.get_trainable_params()
flush()
elif self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
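# These kwargs map directly onto diffusers' T2IAdapter; for SDXL models the
# adapter_type was already suffixed with '_xl' in __init__ (e.g. full_adapter_xl).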
# t2i adapter: try to resume from the latest adapter save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device_torch,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
params = self.get_params()
if not params:
# set trainable params
params = self.adapter.parameters()
self.sd.adapter = self.adapter
flush()
else:
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)

View File

@@ -1,7 +1,7 @@
torch
torchvision
safetensors
git+https://github.com/huggingface/diffusers.git
diffusers==0.21.1
transformers
lycoris_lora
flatten_json

View File

@@ -39,6 +39,7 @@ class SampleConfig:
NetworkType = Literal['lora', 'locon']
class NetworkConfig:
def __init__(self, **kwargs):
self.type: NetworkType = kwargs.get('type', 'lora')
@@ -58,6 +59,17 @@ class NetworkConfig:
self.dropout: Union[float, None] = kwargs.get('dropout', None)
class AdapterConfig:
def __init__(self, **kwargs):
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
self.downscale_factor: int = kwargs.get('downscale_factor', 16)
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None)
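# Example (hypothetical values): the raw 'adapter' section of a job config is passed
# straight through as kwargs, e.g.
#   AdapterConfig(adapter_type='full_adapter', image_dir='/path/to/conditioning_images')
# with the remaining fields falling back to the defaults above.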
class EmbeddingConfig:
def __init__(self, **kwargs):
self.trigger = kwargs.get('trigger', 'custom_embedding')
@@ -66,9 +78,13 @@ class EmbeddingConfig:
self.save_format = kwargs.get('save_format', 'safetensors')
ContentOrStyleType = Literal['balanced', 'style', 'content']
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
self.steps: int = kwargs.get('steps', 1000)
self.lr = kwargs.get('lr', 1e-6)
self.unet_lr = kwargs.get('unet_lr', self.lr)
@@ -80,8 +96,6 @@ class TrainConfig:
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.use_linear_denoising: int = kwargs.get('use_linear_denoising', False)
self.use_progressive_denoising: int = kwargs.get('use_progressive_denoising', False)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
@@ -255,6 +269,7 @@ class GenerateImageConfig:
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
adapter_image_path: str = None, # path to adapter image
):
self.width: int = width
self.height: int = height
@@ -277,6 +292,7 @@ class GenerateImageConfig:
self.add_prompt_file: bool = add_prompt_file
self.output_tail: str = output_tail
self.gen_time: int = int(time.time() * 1000)
self.adapter_image_path: str = adapter_image_path
# prompt string will override any settings above
self._process_prompt_string()

View File

@@ -341,13 +341,7 @@ class ToolkitNetworkMixin:
if isinstance(multiplier, int) or isinstance(multiplier, float):
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
elif isinstance(multiplier, list):
tensor_list = []
for m in multiplier:
if isinstance(m, int) or isinstance(m, float):
tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
elif isinstance(m, torch.Tensor):
tensor_list.append(m.clone().detach().to(device, dtype=dtype))
tensor_multiplier = torch.cat(tensor_list)
tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
elif isinstance(multiplier, torch.Tensor):
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)

View File

@@ -161,10 +161,35 @@ def save_lora_from_diffusers(
else:
converted_key = key
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def save_t2i_from_diffusers(
t2i_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non-diffusers formats
converted_state_dict = OrderedDict()
for key, value in t2i_state_dict.items():
converted_state_dict[key] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def load_t2i_model(
path_to_file,
device: Union[str, torch.device] = 'cpu',
dtype: torch.dtype = torch.float32
):
raw_state_dict = load_file(path_to_file, device)
converted_state_dict = OrderedDict()
for key, value in raw_state_dict.items():
# todo see if we need to convert dict
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
return converted_state_dict

View File

@@ -1,3 +1,5 @@
from typing import Union
import torch
import copy
@@ -15,16 +17,22 @@ empty_preset = {
'training': False,
'requires_grad': False,
'device': 'cpu',
}
},
'adapter': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
}
def get_train_sd_device_state_preset (
device: torch.DeviceObjType,
def get_train_sd_device_state_preset(
device: Union[str, torch.device],
train_unet: bool = False,
train_text_encoder: bool = False,
cached_latents: bool = False,
train_lora: bool = False,
train_adapter: bool = False,
train_embedding: bool = False,
):
preset = copy.deepcopy(empty_preset)
@@ -51,9 +59,14 @@ def get_train_sd_device_state_preset (
preset['text_encoder']['training'] = True
preset['unet']['training'] = True
if train_lora:
preset['text_encoder']['requires_grad'] = False
preset['unet']['requires_grad'] = False
if train_adapter:
preset['adapter']['requires_grad'] = True
preset['adapter']['training'] = True
preset['adapter']['device'] = device
preset['unet']['training'] = True
return preset

View File

@@ -1,3 +1,4 @@
import copy
import gc
import json
import shutil
@@ -7,6 +8,7 @@ import sys
import os
from collections import OrderedDict
from PIL import Image
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
@@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
import diffusers
# tell it to shut up
@@ -110,7 +114,7 @@ class StableDiffusion:
self.unet: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
# sdxl stuff
self.logit_scale = None
@@ -119,6 +123,7 @@ class StableDiffusion:
# to hold network if there is one
self.network = None
self.adapter: Union['T2IAdapter', None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
@@ -291,8 +296,18 @@ class StableDiffusion:
if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline
else:
elif self.is_xl:
Pipe = StableDiffusionXLPipeline
else:
Pipe = StableDiffusionPipeline
extra_args = {}
if self.adapter:
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
# TODO add clip skip
if self.is_xl:
@@ -305,11 +320,12 @@ class StableDiffusion:
tokenizer_2=self.tokenizer[1],
scheduler=noise_scheduler,
add_watermarker=False,
**extra_args
).to(self.device_torch)
# force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
pipeline.watermark = None
else:
pipeline = StableDiffusionPipeline(
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder,
@@ -318,6 +334,7 @@ class StableDiffusion:
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
**extra_args
).to(self.device_torch)
flush()
# disable progress bar
@@ -340,6 +357,12 @@ class StableDiffusion:
for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
gen_config = image_configs[i]
extra = {}
if gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width, gen_config.height))
extra['image'] = validation_image
if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
torch.manual_seed(gen_config.seed)
@@ -355,7 +378,6 @@ class StableDiffusion:
grs = 0.7
# grs = 0.0
extra = {}
if sampler.startswith("sample_"):
extra['use_karras_sigmas'] = True
@@ -379,6 +401,7 @@ class StableDiffusion:
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
**extra
).images[0]
gen_config.save_image(img)
@@ -517,6 +540,7 @@ class StableDiffusion:
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
).sample
if do_classifier_free_guidance:
@@ -558,6 +582,7 @@ class StableDiffusion:
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
**kwargs,
).sample
if do_classifier_free_guidance:
@@ -855,6 +880,7 @@ class StableDiffusion:
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
self.device_state = {
**empty_preset,
'vae': {
'training': self.vae.training,
'device': self.vae.device,
@@ -880,6 +906,12 @@ class StableDiffusion:
'device': self.text_encoder.device,
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
}
if self.adapter is not None:
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.requires_grad,
}
def restore_device_state(self):
# restores the device state for all modules
@@ -927,6 +959,14 @@ class StableDiffusion:
self.text_encoder.eval()
self.text_encoder.to(state['text_encoder']['device'])
self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
if self.adapter is not None:
self.adapter.to(state['adapter']['device'])
self.adapter.requires_grad_(state['adapter']['requires_grad'])
if state['adapter']['training']:
self.adapter.train()
else:
self.adapter.eval()
flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -940,9 +980,9 @@ class StableDiffusion:
if device_state_preset in ['cache_latents']:
active_modules = ['vae']
if device_state_preset in ['generate']:
active_modules = ['vae', 'unet', 'text_encoder']
active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
state = {}
state = copy.deepcopy(empty_preset)
# vae
state['vae'] = {
'training': 'vae' in training_modules,
@@ -973,4 +1013,11 @@ class StableDiffusion:
'requires_grad': 'text_encoder' in training_modules,
}
if self.adapter is not None:
state['adapter'] = {
'training': 'adapter' in training_modules,
'device': self.device_torch if 'adapter' in active_modules else 'cpu',
'requires_grad': 'adapter' in training_modules,
}
self.set_device_state(state)