Added base setup for training T2I adapters. Currently untested; saw something else shiny I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, cubic sampling is used to favor the timesteps that are most beneficial for each: later timesteps for style, earlier timesteps for content.
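
For reference, a minimal sketch of what the cubic timestep sampling could look like, assuming "later" means larger timestep indices (the high-noise end). The function name and signature are illustrative, not the repo's actual API:

import torch

def sample_timesteps(batch_size: int, num_train_timesteps: int,
                     content_or_style: str = "balanced") -> torch.Tensor:
    # u is uniform on [0, 1); u**3 skews toward 0, while 1 - u**3 skews toward 1
    u = torch.rand(batch_size)
    if content_or_style == "style":
        # favor later timesteps (t near num_train_timesteps)
        t = (1.0 - u ** 3) * num_train_timesteps
    elif content_or_style == "content":
        # favor earlier timesteps (t near 0)
        t = (u ** 3) * num_train_timesteps
    else:
        # "balanced": standard uniform sampling
        t = u * num_train_timesteps
    return t.long().clamp(0, num_train_timesteps - 1)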

This commit is contained in:
Jaret Burkett
2023-09-16 08:30:38 -06:00
parent 17e4fe40d7
commit 27f343fc08
8 changed files with 314 additions and 84 deletions

@@ -1,3 +1,4 @@
+import copy
 import gc
 import json
 import shutil
@@ -7,6 +8,7 @@ import sys
 import os
 from collections import OrderedDict
+from PIL import Image
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
 from torch.nn import Parameter
@@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers
+from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline
-from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
 import diffusers

 # tell it to shut up
@@ -110,7 +114,7 @@ class StableDiffusion:
         self.unet: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler
+        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler

         # sdxl stuff
         self.logit_scale = None
@@ -119,6 +123,7 @@ class StableDiffusion:
         # to hold network if there is one
         self.network = None
+        self.adapter: Union['T2IAdapter', None] = None

         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
@@ -291,8 +296,18 @@ class StableDiffusion:
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline
-        else:
+        elif self.is_xl:
             Pipe = StableDiffusionXLPipeline
+        else:
+            Pipe = StableDiffusionPipeline
+
+        extra_args = {}
+        if self.adapter:
+            if self.is_xl:
+                Pipe = StableDiffusionXLAdapterPipeline
+            else:
+                Pipe = StableDiffusionAdapterPipeline
+            extra_args['adapter'] = self.adapter

         # TODO add clip skip
         if self.is_xl:
@@ -305,11 +320,12 @@ class StableDiffusion:
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
                 add_watermarker=False,
+                **extra_args
             ).to(self.device_torch)
             # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
             pipeline.watermark = None
         else:
-            pipeline = StableDiffusionPipeline(
+            pipeline = Pipe(
                 vae=self.vae,
                 unet=self.unet,
                 text_encoder=self.text_encoder,
@@ -318,6 +334,7 @@ class StableDiffusion:
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
+                **extra_args
             ).to(self.device_torch)
         flush()
         # disable progress bar
@@ -340,6 +357,12 @@ class StableDiffusion:
         for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
             gen_config = image_configs[i]

+            extra = {}
+            if gen_config.adapter_image_path is not None:
+                validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
+                validation_image = validation_image.resize((gen_config.width, gen_config.height))
+                extra['image'] = validation_image
+
             if self.network is not None:
                 self.network.multiplier = gen_config.network_multiplier
             torch.manual_seed(gen_config.seed)
@@ -355,7 +378,6 @@ class StableDiffusion:
             grs = 0.7
             # grs = 0.0

-            extra = {}
             if sampler.startswith("sample_"):
                 extra['use_karras_sigmas'] = True
@@ -379,6 +401,7 @@ class StableDiffusion:
                 width=gen_config.width,
                 num_inference_steps=gen_config.num_inference_steps,
                 guidance_scale=gen_config.guidance_scale,
+                **extra
             ).images[0]
             gen_config.save_image(img)
@@ -517,6 +540,7 @@ class StableDiffusion:
                 timestep,
                 encoder_hidden_states=text_embeddings.text_embeds,
                 added_cond_kwargs=added_cond_kwargs,
+                **kwargs,
             ).sample

             if do_classifier_free_guidance:
@@ -558,6 +582,7 @@ class StableDiffusion:
                 latent_model_input,
                 timestep,
                 encoder_hidden_states=text_embeddings.text_embeds,
+                **kwargs,
             ).sample

             if do_classifier_free_guidance:
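
A rough sketch, not part of this commit, of how the **kwargs added to these unet calls could carry T2I adapter features into the UNet. Variable names are illustrative; down_block_additional_residuals is the argument diffusers' UNet2DConditionModel accepted for adapter residuals at the time:

# hypothetical caller-side code, assuming self.adapter is a loaded T2IAdapter
adapter_state = self.adapter(adapter_image_tensor)  # list of per-block feature maps
noise_pred = self.unet(
    latent_model_input,
    timestep,
    encoder_hidden_states=text_embeddings.text_embeds,
    down_block_additional_residuals=[s.clone() for s in adapter_state],
).sample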
@@ -855,6 +880,7 @@ class StableDiffusion:
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
         self.device_state = {
+            **empty_preset,
             'vae': {
                 'training': self.vae.training,
                 'device': self.vae.device,
@@ -880,6 +906,12 @@ class StableDiffusion:
                 'device': self.text_encoder.device,
                 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
             }
+        if self.adapter is not None:
+            self.device_state['adapter'] = {
+                'training': self.adapter.training,
+                'device': self.adapter.device,
+                'requires_grad': self.adapter.requires_grad,
+            }

     def restore_device_state(self):
         # restores the device state for all modules
@@ -927,6 +959,14 @@ class StableDiffusion:
             self.text_encoder.eval()
             self.text_encoder.to(state['text_encoder']['device'])
             self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
+        if self.adapter is not None:
+            self.adapter.to(state['adapter']['device'])
+            self.adapter.requires_grad_(state['adapter']['requires_grad'])
+            if state['adapter']['training']:
+                self.adapter.train()
+            else:
+                self.adapter.eval()
         flush()

     def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -940,9 +980,9 @@ class StableDiffusion:
         if device_state_preset in ['cache_latents']:
             active_modules = ['vae']
         if device_state_preset in ['generate']:
-            active_modules = ['vae', 'unet', 'text_encoder']
+            active_modules = ['vae', 'unet', 'text_encoder', 'adapter']

-        state = {}
+        state = copy.deepcopy(empty_preset)

         # vae
         state['vae'] = {
             'training': 'vae' in training_modules,
@@ -973,4 +1013,11 @@ class StableDiffusion:
             'requires_grad': 'text_encoder' in training_modules,
         }

+        if self.adapter is not None:
+            state['adapter'] = {
+                'training': 'adapter' in training_modules,
+                'device': self.device_torch if 'adapter' in active_modules else 'cpu',
+                'requires_grad': 'adapter' in training_modules,
+            }
+
         self.set_device_state(state)