Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-24 16:29:26 +00:00
Added base setup for training T2I adapters. Currently untested; saw something else shiny I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, timesteps are drawn with cubic sampling to favor the steps most beneficial for that target: later timesteps for style, earlier timesteps for content.
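For reference, a minimal sketch of what cubic timestep sampling like this could look like. The helper name and the exact cubic mapping are illustrative assumptions, not the toolkit's actual implementation:

import torch

def sample_timesteps(batch_size: int, num_timesteps: int,
                     content_or_style: str = "balanced") -> torch.Tensor:
    # start from uniform samples in [0, 1)
    u = torch.rand(batch_size)
    if content_or_style == "style":
        # cubing skews mass toward 0, so 1 - u**3 piles up near 1.0:
        # later, noisier timesteps (global composition / style)
        t = 1.0 - u ** 3
    elif content_or_style == "content":
        # u**3 piles up near 0.0: earlier, cleaner timesteps (fine detail)
        t = u ** 3
    else:
        # "balanced": standard uniform timestep sampling
        t = u
    return (t * (num_timesteps - 1)).long()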
@@ -1,3 +1,4 @@
+import copy
 import gc
 import json
 import shutil
@@ -7,6 +8,7 @@ import sys
 import os
 from collections import OrderedDict
 
+from PIL import Image
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
 from torch.nn import Parameter
@@ -22,11 +24,13 @@ from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers
+from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline
-from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
 import diffusers
 
 # tell it to shut up
@@ -110,7 +114,7 @@ class StableDiffusion:
         self.unet: Union[None, 'UNet2DConditionModel']
         self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
         self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
-        self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler
+        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
 
         # sdxl stuff
         self.logit_scale = None
@@ -119,6 +123,7 @@ class StableDiffusion:
 
         # to hold network if there is one
         self.network = None
+        self.adapter: Union['T2IAdapter', None] = None
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
 
@@ -291,8 +296,18 @@ class StableDiffusion:
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline
-        else:
+        elif self.is_xl:
             Pipe = StableDiffusionXLPipeline
+        else:
+            Pipe = StableDiffusionPipeline
 
+        extra_args = {}
+        if self.adapter:
+            if self.is_xl:
+                Pipe = StableDiffusionXLAdapterPipeline
+            else:
+                Pipe = StableDiffusionAdapterPipeline
+            extra_args['adapter'] = self.adapter
+
         # TODO add clip skip
         if self.is_xl:
@@ -305,11 +320,12 @@ class StableDiffusion:
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
                 add_watermarker=False,
+                **extra_args
             ).to(self.device_torch)
             # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
             pipeline.watermark = None
         else:
-            pipeline = StableDiffusionPipeline(
+            pipeline = Pipe(
                 vae=self.vae,
                 unet=self.unet,
                 text_encoder=self.text_encoder,
@@ -318,6 +334,7 @@ class StableDiffusion:
                 safety_checker=None,
                 feature_extractor=None,
                 requires_safety_checker=False,
+                **extra_args
             ).to(self.device_torch)
         flush()
         # disable progress bar
@@ -340,6 +357,12 @@ class StableDiffusion:
         for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
             gen_config = image_configs[i]
 
+            extra = {}
+            if gen_config.adapter_image_path is not None:
+                validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
+                validation_image = validation_image.resize((gen_config.width, gen_config.height))
+                extra['image'] = validation_image
+
             if self.network is not None:
                 self.network.multiplier = gen_config.network_multiplier
             torch.manual_seed(gen_config.seed)
@@ -355,7 +378,6 @@ class StableDiffusion:
             grs = 0.7
             # grs = 0.0
 
-            extra = {}
             if sampler.startswith("sample_"):
                 extra['use_karras_sigmas'] = True
 
@@ -379,6 +401,7 @@ class StableDiffusion:
                 width=gen_config.width,
                 num_inference_steps=gen_config.num_inference_steps,
                 guidance_scale=gen_config.guidance_scale,
+                **extra
             ).images[0]
 
             gen_config.save_image(img)
@@ -517,6 +540,7 @@ class StableDiffusion:
                 timestep,
                 encoder_hidden_states=text_embeddings.text_embeds,
                 added_cond_kwargs=added_cond_kwargs,
+                **kwargs,
             ).sample
 
             if do_classifier_free_guidance:
@@ -558,6 +582,7 @@ class StableDiffusion:
                 latent_model_input,
                 timestep,
                 encoder_hidden_states=text_embeddings.text_embeds,
+                **kwargs,
             ).sample
 
             if do_classifier_free_guidance:
@@ -855,6 +880,7 @@ class StableDiffusion:
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
         self.device_state = {
+            **empty_preset,
             'vae': {
                 'training': self.vae.training,
                 'device': self.vae.device,
@@ -880,6 +906,12 @@ class StableDiffusion:
                 'device': self.text_encoder.device,
                 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
             }
+        if self.adapter is not None:
+            self.device_state['adapter'] = {
+                'training': self.adapter.training,
+                'device': self.adapter.device,
+                'requires_grad': self.adapter.requires_grad,
+            }
 
     def restore_device_state(self):
         # restores the device state for all modules
@@ -927,6 +959,14 @@ class StableDiffusion:
                 self.text_encoder.eval()
             self.text_encoder.to(state['text_encoder']['device'])
             self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
 
+        if self.adapter is not None:
+            self.adapter.to(state['adapter']['device'])
+            self.adapter.requires_grad_(state['adapter']['requires_grad'])
+            if state['adapter']['training']:
+                self.adapter.train()
+            else:
+                self.adapter.eval()
+        flush()
 
     def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -940,9 +980,9 @@ class StableDiffusion:
         if device_state_preset in ['cache_latents']:
             active_modules = ['vae']
         if device_state_preset in ['generate']:
-            active_modules = ['vae', 'unet', 'text_encoder']
+            active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
 
-        state = {}
+        state = copy.deepcopy(empty_preset)
         # vae
         state['vae'] = {
             'training': 'vae' in training_modules,
@@ -973,4 +1013,11 @@ class StableDiffusion:
             'requires_grad': 'text_encoder' in training_modules,
         }
 
+        if self.adapter is not None:
+            state['adapter'] = {
+                'training': 'adapter' in training_modules,
+                'device': self.device_torch if 'adapter' in active_modules else 'cpu',
+                'requires_grad': 'adapter' in training_modules,
+            }
+
         self.set_device_state(state)
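For context, the pipeline changes above follow the standard diffusers T2I-adapter flow: pick an adapter-aware pipeline class and pass the adapter plus a conditioning image. A minimal usage sketch; the pretrained model IDs and file names are illustrative assumptions, not part of this commit:

import torch
from PIL import Image
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

# load a conditioning adapter (example model IDs only)
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")

# resize the control image to the generation size, as the diff does
control = Image.open("control.png").convert("RGB").resize((1024, 1024))
image = pipeline(
    "a photo of a cat", image=control, num_inference_steps=30
).images[0]
image.save("out.png")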