Big refactor of SD runner and added image generator

This commit is contained in:
Jaret Burkett
2023-08-03 14:51:25 -06:00
parent 75ec5d9292
commit 66c6f0f6f7
16 changed files with 923 additions and 430 deletions


@@ -42,6 +42,16 @@ here so far.
--- ---
### Batch Image Generation
An image generator that can take prompts from a config file or from a txt file and generate them to a
folder. I mainly needed this for an SDXL test I am doing, but added some polish to it so it can be used
for general batch image generation.
It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
More info is in the comments in the example.
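With a config in place, the generator runs like the other jobs in this repo; assuming the usual `run.py` entry point, an invocation would look something like `python run.py config/examples/generate.example.yaml`.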
---
### LoRA (lierla), LoCON (LyCORIS) extractor ### LoRA (lierla), LoCON (LyCORIS) extractor
It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adding some QOL features It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adding some QOL features
@@ -143,6 +153,11 @@ Just went in and out. It is much worse on smaller faces than shown here.
## Change Log ## Change Log
#### 2023-08-03
Another big refactor to make SD more modular.
Added a batch image generation script.
#### 2021-08-01 #### 2021-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable


@@ -0,0 +1,60 @@
---
job: generate # tells the runner what to do
config:
name: "generate" # this is not really used anywhere currently but required by runner
process:
# process 1
- type: to_folder # process images to a folder
output_folder: "output/gen"
device: cuda:0 # cpu, cuda:0, etc
generate:
# these are your defaults; you can override most of them with flags
sampler: "ddpm" # ignored for now; ddpm is always used, other samplers will be added later
width: 1024
height: 1024
neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
seed: -1 # -1 is random
guidance_scale: 7
sample_steps: 20
ext: ".png" # .png, .jpg, .jpeg, .webp
# here are the flags you can use for prompts. Always start with
# your prompt first, then add these flags after. You can use as many
# as you like, e.g.
# photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
# we will try to support all sd-scripts flags where we can
# FROM SD-SCRIPTS
# --n Treat everything until the next option as a negative prompt.
# --w Specify the width of the generated image.
# --h Specify the height of the generated image.
# --d Specify the seed for the generated image.
# --l Specify the CFG scale for the generated image.
# --s Specify the number of steps during generation.
# OURS and some QOL additions
# --p2 Prompt for the second text encoder (SDXL only)
# --n2 Negative prompt for the second text encoder (SDXL only)
# --gr Specify the guidance rescale for the generated image (SDXL only)
# --seed Specify the seed for the generated image same as --d
# --cfg Specify the CFG scale for the generated image same as --l
# --steps Specify the number of steps during generation same as --s
prompt_file: false # if true, a txt file with the prompt string used will be saved next to each image
# prompts can also be a path to a text file with one prompt per line
# prompts: "/path/to/prompts.txt"
prompts:
- "photo of batman"
- "photo of superman"
- "photo of spiderman"
- "photo of a superhero --n batman superman spiderman"
model:
# huggingface name, path relative to the project root, or absolute path to .safetensors or .ckpt
# name_or_path: "runwayml/stable-diffusion-v1-5"
name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
is_v2: false # for v2 models
is_v_pred: false # for v-prediction models (most v2 models)
is_xl: false # for SDXL models
dtype: bf16


@@ -57,7 +57,8 @@ config:
# bf16 works best if your GPU supports it (modern) # bf16 works best if your GPU supports it (modern)
dtype: bf16 # fp32, bf16, fp16 dtype: bf16 # fp32, bf16, fp16
# if you have it, use it. It is faster and better # if you have it, use it. It is faster and better
xformers: true # torch 2.0 doesn't need xformers anymore, only use if you have a lower version
# xformers: true
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
# although, the way we train sliders is comparative, so it probably won't work anyway # although, the way we train sliders is comparative, so it probably won't work anyway
noise_offset: 0.0 noise_offset: 0.0

jobs/GenerateJob.py Normal file

@@ -0,0 +1,32 @@
from jobs import BaseJob
from collections import OrderedDict
from typing import List
from jobs.process import GenerateProcess
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
process_dict = {
'to_folder': 'GenerateProcess',
}
class GenerateJob(BaseJob):
process: List[GenerateProcess]
def __init__(self, config: OrderedDict):
super().__init__(config)
self.device = self.get_conf('device', 'cpu')
# loads the processes from the config
self.load_processes(process_dict)
def run(self):
super().run()
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
for process in self.process:
process.run()
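For reference, a job like this can also be driven directly from a loaded config; a rough sketch, assuming the YAML has the same shape as config/examples/generate.example.yaml (normally run.py and get_job handle this):

import yaml
from collections import OrderedDict

from jobs import GenerateJob

# load the generate config and hand it to the job, as the runner would
with open("config/examples/generate.example.yaml", "r") as f:
    config = OrderedDict(yaml.safe_load(f))

job = GenerateJob(config)
job.run()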


@@ -3,3 +3,4 @@ from .ExtractJob import ExtractJob
from .TrainJob import TrainJob from .TrainJob import TrainJob
from .MergeJob import MergeJob from .MergeJob import MergeJob
from .ModJob import ModJob from .ModJob import ModJob
from .GenerateJob import GenerateJob


@@ -1,10 +1,9 @@
import copy import copy
import json import json
from collections import OrderedDict from collections import OrderedDict
from typing import ForwardRef
class BaseProcess: class BaseProcess(object):
meta: OrderedDict meta: OrderedDict
def __init__( def __init__(


@@ -1,34 +1,23 @@
import glob import glob
import time
from collections import OrderedDict from collections import OrderedDict
import os import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from toolkit.lora_special import LoRASpecialNetwork from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer from toolkit.optimizer import get_optimizer
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline from toolkit.scheduler import get_lr_scheduler
from toolkit.stable_diffusion_model import StableDiffusion
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler, PNDMScheduler, \
DDIMScheduler, DDPMScheduler
from jobs.process import BaseTrainProcess from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.train_tools import get_torch_dtype, apply_noise_offset from toolkit.train_tools import get_torch_dtype
import gc import gc
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from leco import train_util, model_util from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig GenerateImageConfig
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
def flush(): def flush():
@@ -36,11 +25,9 @@ def flush():
gc.collect() gc.collect()
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4. Same for XL.
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class BaseSDTrainProcess(BaseTrainProcess): class BaseSDTrainProcess(BaseTrainProcess):
sd: StableDiffusion
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
super().__init__(process_id, job, config) super().__init__(process_id, job, config)
self.custom_pipeline = custom_pipeline self.custom_pipeline = custom_pipeline
@@ -64,177 +51,52 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.logging_config = LogingConfig(**self.get_conf('logging', {})) self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None self.optimizer = None
self.lr_scheduler = None self.lr_scheduler = None
self.sd: 'StableDiffusion' = None
# sdxl stuff self.sd = StableDiffusion(
self.logit_scale = None device=self.device,
self.ckppt_info = None model_config=self.model_config,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
)
# added later # to hold network if there is one
self.network = None self.network = None
def sample(self, step=None, is_first=False): def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples') sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder): gen_img_config_list = []
os.makedirs(sample_folder, exist_ok=True)
if self.network is not None:
self.network.eval()
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
original_device_dict = {
'vae': self.sd.vae.device,
'unet': self.sd.unet.device,
# 'tokenizer': self.sd.tokenizer.device,
}
# handle sdxl text encoder
if isinstance(self.sd.text_encoder, list):
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
original_device_dict[f'text_encoder_{i}'] = encoder.device
encoder.to(self.device_torch)
else:
original_device_dict['text_encoder'] = self.sd.text_encoder.device
self.sd.text_encoder.to(self.device_torch)
self.sd.vae.to(self.device_torch)
self.sd.unet.to(self.device_torch)
# self.sd.text_encoder.to(self.device_torch)
# self.sd.tokenizer.to(self.device_torch)
# TODO add clip skip
if self.sd.is_xl:
pipeline = StableDiffusionXLPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=self.sd.noise_scheduler,
).to(self.device_torch)
else:
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=self.sd.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
).to(self.device_torch)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
sample_config = self.first_sample_config if is_first else self.sample_config sample_config = self.first_sample_config if is_first else self.sample_config
start_seed = sample_config.seed start_seed = sample_config.seed
start_multiplier = self.network.multiplier
current_seed = start_seed current_seed = start_seed
for i in range(len(sample_config.prompts)):
if sample_config.walk_seed:
current_seed = start_seed + i
pipeline.to(self.device_torch) step_num = ''
with self.network: if step is not None:
with torch.no_grad(): # zero-pad 9 digits
if self.network is not None: step_num = f"_{str(step).zfill(9)}"
assert self.network.is_active
if self.logging_config.verbose:
print("network_state", {
'is_active': self.network.is_active,
'multiplier': self.network.multiplier,
})
for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}", filename = f"[time]_{step_num}_[count].png"
leave=False):
raw_prompt = sample_config.prompts[i]
neg = sample_config.neg output_path = os.path.join(sample_folder, filename)
multiplier = sample_config.network_multiplier
p_split = raw_prompt.split('--')
prompt = p_split[0].strip()
height = sample_config.height
width = sample_config.width
if len(p_split) > 1: gen_img_config_list.append(GenerateImageConfig(
for split in p_split: prompt=sample_config.prompts[i], # it will autoparse the prompt
flag = split[:1] width=sample_config.width,
content = split[1:].strip() height=sample_config.height,
if flag == 'n': negative_prompt=sample_config.neg,
neg = content seed=current_seed,
elif flag == 'm': guidance_scale=sample_config.guidance_scale,
# multiplier guidance_rescale=sample_config.guidance_rescale,
multiplier = float(content) num_inference_steps=sample_config.sample_steps,
elif flag == 'w': network_multiplier=sample_config.network_multiplier,
# multiplier output_path=output_path,
width = int(content) ))
elif flag == 'h':
# multiplier
height = int(content)
height = max(64, height - height % 8) # round to divisible by 8 # send to be generated
width = max(64, width - width % 8) # round to divisible by 8 self.sd.generate_images(gen_img_config_list)
if sample_config.walk_seed:
current_seed += i
if self.network is not None:
self.network.multiplier = multiplier
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
if self.sd.is_xl:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
guidance_rescale=0.7,
).images[0]
else:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
step_num = ''
if step is not None:
# zero-pad 9 digits
step_num = f"_{str(step).zfill(9)}"
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_path = os.path.join(sample_folder, filename)
img.save(output_path)
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
# restore training state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
self.sd.vae.to(original_device_dict['vae'])
self.sd.unet.to(original_device_dict['unet'])
if isinstance(self.sd.text_encoder, list):
for encoder, i in zip(self.sd.text_encoder, range(len(self.sd.text_encoder))):
encoder.to(original_device_dict[f'text_encoder_{i}'])
else:
self.sd.text_encoder.to(original_device_dict['text_encoder'])
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self): def update_training_metadata(self):
o_dict = OrderedDict({ o_dict = OrderedDict({
@@ -328,148 +190,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self): def hook_before_train_loop(self):
pass pass
def get_latent_noise(
self,
height=None,
width=None,
pixel_height=None,
pixel_width=None,
):
if height is None and pixel_height is None:
raise ValueError("height or pixel_height must be specified")
if width is None and pixel_width is None:
raise ValueError("width or pixel_width must be specified")
if height is None:
height = pixel_height // VAE_SCALE_FACTOR
if width is None:
width = pixel_width // VAE_SCALE_FACTOR
noise = torch.randn(
(
self.train_config.batch_size,
UNET_IN_CHANNELS,
height,
width,
),
device="cpu",
)
noise = apply_noise_offset(noise, self.train_config.noise_offset)
return noise
def hook_train_loop(self): def hook_train_loop(self):
# return loss # return loss
return 0.0 return 0.0
def get_time_ids_from_latents(self, latents):
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = get_torch_dtype(self.train_config.dtype)
if self.sd.is_xl:
prompt_ids = train_util.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
dtype=dtype,
).to(self.device_torch, dtype=dtype)
return train_util.concat_embeddings(
prompt_ids, prompt_ids, bs
)
else:
return None
def predict_noise(
self,
latents: torch.FloatTensor,
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0, # 0.7
add_time_ids=None,
**kwargs,
):
if self.sd.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
else:
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
self,
latents: torch.FloatTensor,
text_embeddings: PromptEmbeds,
total_timesteps: int = 1000,
start_timesteps=0,
guidance_scale=1,
add_time_ids=None,
**kwargs,
):
for timestep in tqdm(self.sd.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
noise_pred = self.predict_noise(
latents,
text_embeddings,
timestep,
guidance_scale=guidance_scale,
add_time_ids=add_time_ids,
**kwargs,
)
latents = self.sd.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
# return latents_steps
return latents
def get_latest_save_path(self): def get_latest_save_path(self):
# get latest saved step # get latest saved step
if os.path.exists(self.save_root): if os.path.exists(self.save_root):
@@ -497,92 +221,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
print("load_weights not implemented for non-network models") print("load_weights not implemented for non-network models")
def run(self): def run(self):
super().run() # run base process run
BaseTrainProcess.run(self)
### HOOK ### ### HOOK ###
self.hook_before_model_load() self.hook_before_model_load()
# run base sd process run
self.sd.load_model()
dtype = get_torch_dtype(self.train_config.dtype) dtype = get_torch_dtype(self.train_config.dtype)
# TODO handle other schedulers # model is loaded from BaseSDProcess
# sch = KDPM2DiscreteScheduler unet = self.sd.unet
sch = DDPMScheduler vae = self.sd.vae
# do our own scheduler tokenizer = self.sd.tokenizer
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" text_encoder = self.sd.text_encoder
scheduler = sch( noise_scheduler = self.sd.noise_scheduler
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
)
if self.model_config.is_xl:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionXLPipeline
pipe = pipln.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
else:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionPipeline
pipe = pipln.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch,
load_safety_checker=False,
).to(self.device_torch)
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
tokenizer = pipe.tokenizer
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
unet = pipe.unet
noise_scheduler = pipe.scheduler
vae = pipe.vae.to('cpu', dtype=dtype)
vae.eval()
vae.requires_grad_(False)
flush()
self.sd = StableDiffusion(
vae,
tokenizer,
text_encoder,
unet,
noise_scheduler,
is_xl=self.model_config.is_xl,
pipeline=pipe
)
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers: if self.train_config.xformers:
vae.set_use_memory_efficient_attention_xformers(True) vae.set_use_memory_efficient_attention_xformers(True)
unet.enable_xformers_memory_efficient_attention() unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing: if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
unet.to(self.device_torch, dtype=dtype)
unet.requires_grad_(False) unet.requires_grad_(False)
unet.eval() unet.eval()
vae = vae.to(torch.device('cpu'), dtype=dtype)
vae.requires_grad_(False)
vae.eval()
if self.network_config is not None: if self.network_config is not None:
self.network = LoRASpecialNetwork( self.network = LoRASpecialNetwork(
@@ -598,6 +263,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
) )
self.network.force_to(self.device_torch, dtype=dtype) self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network.apply_to( self.network.apply_to(
text_encoder, text_encoder,
@@ -650,7 +317,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
optimizer_params=self.train_config.optimizer_params) optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer self.optimizer = optimizer
lr_scheduler = train_util.get_lr_scheduler( lr_scheduler = get_lr_scheduler(
self.train_config.lr_scheduler, self.train_config.lr_scheduler,
optimizer, optimizer,
max_iterations=self.train_config.steps, max_iterations=self.train_config.steps,


@@ -0,0 +1,102 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List
import torch
from safetensors.torch import save_file, load_file
from jobs.process.BaseProcess import BaseProcess
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
add_base_model_info_to_meta
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
class GenerateConfig:
prompts: List[str]
def __init__(self, **kwargs):
self.sampler = kwargs.get('sampler', 'ddpm')
self.width = kwargs.get('width', 512)
self.height = kwargs.get('height', 512)
self.neg = kwargs.get('neg', '')
self.seed = kwargs.get('seed', -1)
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
self.prompt_2 = kwargs.get('prompt_2', None)
self.neg_2 = kwargs.get('neg_2', None)
self.prompts = kwargs.get('prompts', None)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext = kwargs.get('ext', 'png')
self.prompt_file = kwargs.get('prompt_file', False)
if self.prompts is None:
raise ValueError("Prompts must be set")
if isinstance(self.prompts, str):
if os.path.exists(self.prompts):
with open(self.prompts, 'r') as f:
self.prompts = f.read().splitlines()
self.prompts = [p.strip() for p in self.prompts if len(p.strip()) > 0]
else:
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
class GenerateProcess(BaseProcess):
process_id: int
config: OrderedDict
progress_bar: ForwardRef('tqdm') = None
sd: StableDiffusion
def __init__(
self,
process_id: int,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.output_folder = self.get_conf('output_folder', required=True)
self.model_config = ModelConfig(**self.get_conf('model', required=True))
self.device = self.get_conf('device', self.job.device)
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
self.progress_bar = None
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
dtype=self.model_config.dtype,
)
print(f"Using device {self.device}")
def run(self):
super().run()
print("Loading model...")
self.sd.load_model()
print(f"Generating {len(self.generate_config.prompts)} images")
# build prompt image configs
prompt_image_configs = []
for prompt in self.generate_config.prompts:
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=self.generate_config.width,
height=self.generate_config.height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,
negative_prompt_2=self.generate_config.neg_2,
seed=self.generate_config.seed,
guidance_rescale=self.generate_config.guidance_rescale,
output_ext=self.generate_config.ext,
output_folder=self.output_folder,
add_prompt_file=self.generate_config.prompt_file
))
# generate images
self.sd.generate_images(prompt_image_configs)
print("Done generating images")
# cleanup
del self.sd
gc.collect()
torch.cuda.empty_cache()


@@ -202,9 +202,11 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
) )
# get noise # get noise
noise = self.get_latent_noise( noise = self.sd.get_latent_noise(
pixel_height=self.rescale_config.from_resolution, pixel_height=self.rescale_config.from_resolution,
pixel_width=self.rescale_config.from_resolution, pixel_width=self.rescale_config.from_resolution,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype) ).to(self.device_torch, dtype=dtype)
torch.set_default_device(self.device_torch) torch.set_default_device(self.device_torch)
@@ -238,7 +240,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
) )
with torch.no_grad(): with torch.no_grad():
noise_pred_target = self.predict_noise( noise_pred_target = self.sd.predict_noise(
latents, latents,
text_embeddings=text_embeddings, text_embeddings=text_embeddings,
timestep=timestep, timestep=timestep,
@@ -256,7 +258,7 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
with self.network: with self.network:
assert self.network.is_active assert self.network.is_active
self.network.multiplier = 1.0 self.network.multiplier = 1.0
noise_pred_train = self.predict_noise( noise_pred_train = self.sd.predict_noise(
reduced_latents, reduced_latents,
text_embeddings=text_embeddings, text_embeddings=text_embeddings,
timestep=timestep, timestep=timestep,


@@ -1,7 +1,6 @@
# ref: # ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py # - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import random import random
import time
from collections import OrderedDict from collections import OrderedDict
import os import os
from typing import Optional from typing import Optional
@@ -14,16 +13,12 @@ from toolkit.paths import REPOS_ROOT
import sys import sys
from toolkit.stable_diffusion_model import PromptEmbeds from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc import gc
from toolkit import train_tools from toolkit import train_tools
import torch import torch
from leco import train_util, model_util from .BaseSDTrainProcess import BaseSDTrainProcess
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
class ACTION_TYPES_SLIDER: class ACTION_TYPES_SLIDER:
@@ -131,7 +126,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
if not self.slider_config.prompt_tensors: if not self.slider_config.prompt_tensors:
# shuffle # shuffle
random.shuffle(self.prompt_txt_list) random.shuffle(self.prompt_txt_list)
@@ -175,8 +169,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False): for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets: for target in self.slider_config.targets:
prompt_list = [ prompt_list = [
f"{target.target_class}", # target_class f"{target.target_class}", # target_class
f"{target.target_class} {neutral}", # target_class with neutral f"{target.target_class} {neutral}", # target_class with neutral
f"{target.positive}", # positive_target f"{target.positive}", # positive_target
f"{target.positive} {neutral}", # positive_target with neutral f"{target.positive} {neutral}", # positive_target with neutral
f"{target.negative}", # negative_target f"{target.negative}", # negative_target
@@ -320,7 +314,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
) )
] ]
# move to cpu to save vram # move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling # We don't need text encoder anymore, but keep it on cpu for sampling
# if text encoder is list # if text encoder is list
@@ -364,7 +357,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
loss_function = torch.nn.MSELoss() loss_function = torch.nn.MSELoss()
def get_noise_pred(neg, pos, gs, cts, dn): def get_noise_pred(neg, pos, gs, cts, dn):
return self.predict_noise( return self.sd.predict_noise(
latents=dn, latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings( text_embeddings=train_tools.concat_prompt_embeddings(
neg, # negative prompt neg, # negative prompt
@@ -391,9 +384,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
).item() ).item()
# get noise # get noise
noise = self.get_latent_noise( noise = self.sd.get_latent_noise(
pixel_height=height, pixel_height=height,
pixel_width=width, pixel_width=width,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype) ).to(self.device_torch, dtype=dtype)
# get latents # get latents
@@ -403,7 +398,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network: with self.network:
assert self.network.is_active assert self.network.is_active
self.network.multiplier = multiplier * rand_weight self.network.multiplier = multiplier * rand_weight
denoised_latents = self.diffuse_some_steps( denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents latents, # pass simple noise latents
train_tools.concat_prompt_embeddings( train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional prompt_pair.positive_target, # unconditional


@@ -245,7 +245,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
loss_function = torch.nn.MSELoss() loss_function = torch.nn.MSELoss()
def get_noise_pred(p, n, gs, cts, dn): def get_noise_pred(p, n, gs, cts, dn):
return self.predict_noise( return self.sd.predict_noise(
latents=dn, latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings( text_embeddings=train_tools.concat_prompt_embeddings(
p, # unconditional p, # unconditional
@@ -272,9 +272,11 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
).item() ).item()
# get noise # get noise
noise = self.get_latent_noise( noise = self.sd.get_latent_noise(
pixel_height=height, pixel_height=height,
pixel_width=width, pixel_width=width,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype) ).to(self.device_torch, dtype=dtype)
# get latents # get latents
@@ -284,7 +286,7 @@ class TrainSliderProcessOld(BaseSDTrainProcess):
with self.network: with self.network:
assert self.network.is_active assert self.network.is_active
self.network.multiplier = multiplier self.network.multiplier = multiplier
denoised_latents = self.diffuse_some_steps( denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents latents, # pass simple noise latents
train_tools.concat_prompt_embeddings( train_tools.concat_prompt_embeddings(
positive, # unconditional positive, # unconditional


@@ -10,3 +10,4 @@ from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess


@@ -1,4 +1,7 @@
from typing import List import os
import time
from typing import List, Optional
import random
class SaveConfig: class SaveConfig:
@@ -27,6 +30,7 @@ class SampleConfig:
self.guidance_scale = kwargs.get('guidance_scale', 7) self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20) self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1) self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
class NetworkConfig: class NetworkConfig:
@@ -35,7 +39,7 @@ class NetworkConfig:
rank = kwargs.get('rank', None) rank = kwargs.get('rank', None)
linear = kwargs.get('linear', None) linear = kwargs.get('linear', None)
if rank is not None: if rank is not None:
self.rank: int = rank # rank for backward compatibility self.rank: int = rank # rank for backward compatibility
self.linear: int = rank self.linear: int = rank
elif linear is not None: elif linear is not None:
self.rank: int = linear self.rank: int = linear
@@ -71,6 +75,7 @@ class ModelConfig:
self.is_v2: bool = kwargs.get('is_v2', False) self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False) self.is_xl: bool = kwargs.get('is_xl', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False) self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
if self.name_or_path is None: if self.name_or_path is None:
raise ValueError('name_or_path must be specified') raise ValueError('name_or_path must be specified')
@@ -103,3 +108,197 @@ class SliderConfig:
self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]]) self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
self.prompt_file: str = kwargs.get('prompt_file', None) self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None) self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
class GenerateImageConfig:
def __init__(
self,
prompt: str = '',
prompt_2: Optional[str] = None,
width: int = 512,
height: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: str = '',
negative_prompt_2: Optional[str] = None,
seed: int = -1,
network_multiplier: float = 1.0,
guidance_rescale: float = 0.0,
# the tag [time] will be replaced with milliseconds since epoch
output_path: str = None, # full image path
output_folder: str = None, # folder to save image in if output_path is not specified
output_ext: str = 'png', # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
):
self.width: int = width
self.height: int = height
self.num_inference_steps: int = num_inference_steps
self.guidance_scale: float = guidance_scale
self.guidance_rescale: float = guidance_rescale
self.prompt: str = prompt
self.prompt_2: str = prompt_2
self.negative_prompt: str = negative_prompt
self.negative_prompt_2: str = negative_prompt_2
self.output_path: str = output_path
self.seed: int = seed
if self.seed == -1:
# generate random one
self.seed = random.randint(0, 2 ** 32 - 1)
self.network_multiplier: float = network_multiplier
self.output_folder: str = output_folder
self.output_ext: str = output_ext
self.add_prompt_file: bool = add_prompt_file
self.output_tail: str = output_tail
self.gen_time: int = int(time.time() * 1000)
# prompt string will override any settings above
self._process_prompt_string()
# handle dual text encoder prompts if nothing passed
if negative_prompt_2 is None:
self.negative_prompt_2 = negative_prompt
if prompt_2 is None:
self.prompt_2 = prompt
# parse prompt paths
if self.output_path is None and self.output_folder is None:
raise ValueError('output_path or output_folder must be specified')
elif self.output_path is not None:
self.output_folder = os.path.dirname(self.output_path)
self.output_ext = os.path.splitext(self.output_path)[1][1:]
self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0]
else:
self.output_filename_no_ext = '[time]_[count]'
if len(self.output_tail) > 0:
self.output_filename_no_ext += '_' + self.output_tail
self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' + self.output_ext)
# adjust height
self.height = max(64, self.height - self.height % 8) # round to divisible by 8
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
def set_gen_time(self, gen_time: int = None):
if gen_time is not None:
self.gen_time = gen_time
else:
self.gen_time = int(time.time() * 1000)
def _get_path_no_ext(self, count: int = 0, max_count=0):
# zero pad count
count_str = str(count).zfill(len(str(max_count)))
# replace [time] with gen time
filename = self.output_filename_no_ext.replace('[time]', str(self.gen_time))
# replace [count] with count
filename = filename.replace('[count]', count_str)
return filename
def get_image_path(self, count: int = 0, max_count=0):
filename = self._get_path_no_ext(count, max_count)
filename += '.' + self.output_ext
# join with folder
return os.path.join(self.output_folder, filename)
def get_prompt_path(self, count: int = 0, max_count=0):
filename = self._get_path_no_ext(count, max_count)
filename += '.txt'
# join with folder
return os.path.join(self.output_folder, filename)
def save_image(self, image, count: int = 0, max_count=0):
# make parent dirs
os.makedirs(self.output_folder, exist_ok=True)
self.set_gen_time()
# TODO save image gen header info for A1111 and us, our seeds probably wont match
image.save(self.get_image_path(count, max_count))
# do prompt file
if self.add_prompt_file:
self.save_prompt_file(count, max_count)
def save_prompt_file(self, count: int = 0, max_count=0):
# save prompt file
with open(self.get_prompt_path(count, max_count), 'w') as f:
prompt = self.prompt
if self.prompt_2 is not None:
prompt += ' --p2 ' + self.prompt_2
if self.negative_prompt is not None:
prompt += ' --n ' + self.negative_prompt
if self.negative_prompt_2 is not None:
prompt += ' --n2 ' + self.negative_prompt_2
prompt += ' --w ' + str(self.width)
prompt += ' --h ' + str(self.height)
prompt += ' --seed ' + str(self.seed)
prompt += ' --cfg ' + str(self.guidance_scale)
prompt += ' --steps ' + str(self.num_inference_steps)
prompt += ' --m ' + str(self.network_multiplier)
prompt += ' --gr ' + str(self.guidance_rescale)
# write out the full prompt string, including the flags built above
f.write(prompt)
def _process_prompt_string(self):
# we will try to support all sd-scripts where we can
# FROM SD-SCRIPTS
# --n Treat everything until the next option as a negative prompt.
# --w Specify the width of the generated image.
# --h Specify the height of the generated image.
# --d Specify the seed for the generated image.
# --l Specify the CFG scale for the generated image.
# --s Specify the number of steps during generation.
# OURS and some QOL additions
# --m Specify the network multiplier for the generated image.
# --p2 Prompt for the second text encoder (SDXL only)
# --n2 Negative prompt for the second text encoder (SDXL only)
# --gr Specify the guidance rescale for the generated image (SDXL only)
# --seed Specify the seed for the generated image same as --d
# --cfg Specify the CFG scale for the generated image same as --l
# --steps Specify the number of steps during generation same as --s
# --network_multiplier Specify the network multiplier for the generated image same as --m
# process prompt string and update values if it has some
if self.prompt is not None and len(self.prompt) > 0:
# process prompt string
prompt = self.prompt
prompt = prompt.strip()
p_split = prompt.split('--')
self.prompt = p_split[0].strip()
if len(p_split) > 1:
for split in p_split[1:]:
# allows multi char flags
flag = split.split(' ')[0].strip()
content = split[len(flag):].strip()
if flag == 'p2':
self.prompt_2 = content
elif flag == 'n':
self.negative_prompt = content
elif flag == 'n2':
self.negative_prompt_2 = content
elif flag == 'w':
self.width = int(content)
elif flag == 'h':
self.height = int(content)
elif flag == 'd':
self.seed = int(content)
elif flag == 'seed':
self.seed = int(content)
elif flag == 'l':
self.guidance_scale = float(content)
elif flag == 'cfg':
self.guidance_scale = float(content)
elif flag == 's':
self.num_inference_steps = int(content)
elif flag == 'steps':
self.num_inference_steps = int(content)
elif flag == 'm':
self.network_multiplier = float(content)
elif flag == 'network_multiplier':
self.network_multiplier = float(content)
elif flag == 'gr':
self.guidance_rescale = float(content)
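As a sketch of what the flag parsing above produces (the output folder is just an illustrative path):

cfg = GenerateImageConfig(
    prompt="photo of a baseball --n painting, ugly --w 1024 --h 768 --seed 42 --cfg 7 --steps 20",
    output_folder="output/gen",
)
# after parsing:
#   cfg.prompt == "photo of a baseball"
#   cfg.negative_prompt == "painting, ugly"
#   cfg.width == 1024 and cfg.height == 768
#   cfg.seed == 42, cfg.guidance_scale == 7.0, cfg.num_inference_steps == 20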


@@ -16,6 +16,9 @@ def get_job(config_path, name=None):
if job == 'mod': if job == 'mod':
from jobs import ModJob from jobs import ModJob
return ModJob(config) return ModJob(config)
if job == 'generate':
from jobs import GenerateJob
return GenerateJob(config)
# elif job == 'train': # elif job == 'train':
# from jobs import TrainJob # from jobs import TrainJob

toolkit/scheduler.py Normal file

@@ -0,0 +1,33 @@
import torch
from typing import Optional
def get_lr_scheduler(
name: Optional[str],
optimizer: torch.optim.Optimizer,
max_iterations: Optional[int],
lr_min: Optional[float],
**kwargs,
):
if name == "cosine":
return torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
)
elif name == "cosine_with_restarts":
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
)
elif name == "step":
return torch.optim.lr_scheduler.StepLR(
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
)
elif name == "constant":
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
elif name == "linear":
return torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
)
else:
raise ValueError(
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
)
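A minimal usage sketch for this scheduler factory (the model and hyperparameters are illustrative):

import torch
from toolkit.scheduler import get_lr_scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_lr_scheduler("cosine", optimizer, max_iterations=1000, lr_min=1e-6)

for _ in range(1000):
    optimizer.step()  # the actual training step would go here
    scheduler.step()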


@@ -1,12 +1,16 @@
import gc
import typing import typing
from typing import Union, OrderedDict from typing import Union, OrderedDict, List
import sys import sys
import os import os
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file from safetensors.torch import save_file
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.paths import REPOS_ROOT from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype from toolkit.train_tools import get_torch_dtype, apply_noise_offset
sys.path.append(REPOS_ROOT) sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco')) sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
@@ -14,6 +18,32 @@ from leco import train_util
import torch import torch
from library import model_util from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
from diffusers.schedulers import DDPMScheduler
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
class BlankNetwork:
multiplier = 1.0
is_active = True
def __init__(self):
pass
def __enter__(self):
self.is_active = True
def __exit__(self, exc_type, exc_val, exc_tb):
self.is_active = False
def flush():
torch.cuda.empty_cache()
gc.collect()
UNET_IN_CHANNELS = 4 # Stable Diffusion's in_channels is fixed at 4. Same for XL.
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class PromptEmbeds: class PromptEmbeds:
@@ -39,31 +69,382 @@ class PromptEmbeds:
# if is type checking # if is type checking
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from diffusers import StableDiffusionPipeline from diffusers import \
from toolkit.pipelines import CustomStableDiffusionXLPipeline StableDiffusionPipeline, \
AutoencoderKL, \
UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
class StableDiffusion: class StableDiffusion:
pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline'] pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline']
vae: Union[None, 'AutoencoderKL']
unet: Union[None, 'UNet2DConditionModel']
text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler']
device: str
dtype: str
torch_dtype: torch.dtype
device_torch: torch.device
model_config: ModelConfig
def __init__( def __init__(
self, self,
vae, device,
tokenizer, model_config: ModelConfig,
text_encoder, dtype='fp16',
unet, custom_pipeline=None
noise_scheduler,
is_xl=False,
pipeline=None,
): ):
# text encoder has a list of 2 for xl self.custom_pipeline = custom_pipeline
self.vae = vae self.device = device
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
self.device_torch = torch.device(self.device)
self.model_config = model_config
self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
# to hold network if there is one
self.network = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
def load_model(self):
dtype = get_torch_dtype(self.dtype)
# TODO handle other schedulers
# sch = KDPM2DiscreteScheduler
sch = DDPMScheduler
# do our own scheduler
prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
scheduler = sch(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
clip_sample=False,
prediction_type=prediction_type,
steps_offset=1
)
if self.model_config.is_xl:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionXLPipeline
pipe = pipln.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='ddpm',
device=self.device_torch,
).to(self.device_torch)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
for text_encoder in text_encoders:
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
text_encoder = text_encoders
else:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
pipln = CustomStableDiffusionPipeline
pipe = pipln.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch,
load_safety_checker=False,
).to(self.device_torch)
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
text_encoder.to(self.device_torch, dtype=dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
tokenizer = pipe.tokenizer
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
self.unet = pipe.unet
self.noise_scheduler = pipe.scheduler
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
self.unet.to(self.device_torch, dtype=dtype)
self.unet.requires_grad_(False)
self.unet.eval()
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.text_encoder = text_encoder self.text_encoder = text_encoder
self.unet = unet self.pipeline = pipe
self.noise_scheduler = noise_scheduler
self.is_xl = is_xl def generate_images(self, image_configs: List[GenerateImageConfig]):
self.pipeline = pipeline # sample_folder = os.path.join(self.save_root, 'samples')
if self.network is not None:
self.network.eval()
network = self.network
else:
network = BlankNetwork()
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
original_device_dict = {
'vae': self.vae.device,
'unet': self.unet.device,
# 'tokenizer': self.tokenizer.device,
}
# handle sdxl text encoder
if isinstance(self.text_encoder, list):
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
original_device_dict[f'text_encoder_{i}'] = encoder.device
encoder.to(self.device_torch)
else:
original_device_dict['text_encoder'] = self.text_encoder.device
self.text_encoder.to(self.device_torch)
self.vae.to(self.device_torch)
self.unet.to(self.device_torch)
# TODO add clip skip
if self.is_xl:
pipeline = StableDiffusionXLPipeline(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
scheduler=self.noise_scheduler,
add_watermarker=False,
).to(self.device_torch)
# force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
pipeline.watermark = None
else:
pipeline = StableDiffusionPipeline(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=self.noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
).to(self.device_torch)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
start_multiplier = 1.0
if self.network is not None:
start_multiplier = self.network.multiplier
pipeline.to(self.device_torch)
with network:
with torch.no_grad():
if self.network is not None:
assert self.network.is_active
for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
gen_config = image_configs[i]
if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
if self.is_xl:
img = pipeline(
prompt=gen_config.prompt,
prompt_2=gen_config.prompt_2,
negative_prompt=gen_config.negative_prompt,
negative_prompt_2=gen_config.negative_prompt_2,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=gen_config.guidance_rescale,
).images[0]
else:
img = pipeline(
prompt=gen_config.prompt,
negative_prompt=gen_config.negative_prompt,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
).images[0]
gen_config.save_image(img)
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
# restore training state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
self.vae.to(original_device_dict['vae'])
self.unet.to(original_device_dict['unet'])
if isinstance(self.text_encoder, list):
for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))):
encoder.to(original_device_dict[f'text_encoder_{i}'])
else:
self.text_encoder.to(original_device_dict['text_encoder'])
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
# self.tokenizer.to(original_device_dict['tokenizer'])
def get_latent_noise(
self,
height=None,
width=None,
pixel_height=None,
pixel_width=None,
batch_size=1,
noise_offset=0.0,
):
if height is None and pixel_height is None:
raise ValueError("height or pixel_height must be specified")
if width is None and pixel_width is None:
raise ValueError("width or pixel_width must be specified")
if height is None:
height = pixel_height // VAE_SCALE_FACTOR
if width is None:
width = pixel_width // VAE_SCALE_FACTOR
noise = torch.randn(
(
batch_size,
UNET_IN_CHANNELS,
height,
width,
),
device="cpu",
)
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor):
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
if self.is_xl:
prompt_ids = train_util.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
dtype=dtype,
).to(self.device_torch, dtype=dtype)
return train_util.concat_embeddings(
prompt_ids, prompt_ids, bs
)
else:
return None
def predict_noise(
self,
latents: torch.FloatTensor,
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0, # 0.7
add_time_ids=None,
**kwargs,
):
if self.is_xl:
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
else:
# if we are doing classifier free guidance, need to double up
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
def diffuse_some_steps(
self,
latents: torch.FloatTensor,
text_embeddings: PromptEmbeds,
total_timesteps: int = 1000,
start_timesteps=0,
guidance_scale=1,
add_time_ids=None,
**kwargs,
):
for timestep in tqdm(self.noise_scheduler.timesteps[start_timesteps:total_timesteps], leave=False):
noise_pred = self.predict_noise(
latents,
text_embeddings,
timestep,
guidance_scale=guidance_scale,
add_time_ids=add_time_ids,
**kwargs,
)
latents = self.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
# return latents_steps
return latents
def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds: def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
prompt = prompt prompt = prompt