Merge remote-tracking branch 'origin/development'

# Conflicts:
#	toolkit/stable_diffusion_model.py
This commit is contained in:
Jaret Burkett
2023-11-28 10:40:05 -07:00
89 changed files with 22797 additions and 4003 deletions

6
.gitmodules vendored
View File

@@ -4,3 +4,9 @@
[submodule "repositories/leco"]
path = repositories/leco
url = https://github.com/p1atdev/LECO
[submodule "repositories/batch_annotator"]
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator
[submodule "repositories/ipadapter"]
path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git

View File

@@ -0,0 +1,102 @@
import os
from collections import OrderedDict
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype
def flush():
torch.cuda.empty_cache()
gc.collect()
class PureLoraGenerator(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.output_folder = self.get_conf('output_folder', required=True)
self.device = self.get_conf('device', 'cuda')
self.device_torch = torch.device(self.device)
self.model_config = ModelConfig(**self.get_conf('model', required=True))
self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
self.dtype = self.get_conf('dtype', 'float16')
self.torch_dtype = get_torch_dtype(self.dtype)
lorm_config = self.get_conf('lorm', None)
self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
self.device_state_preset = get_train_sd_device_state_preset(
device=torch.device(self.device),
)
self.progress_bar = None
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
dtype=self.dtype,
)
def run(self):
super().run()
print("Loading model...")
with torch.no_grad():
self.sd.load_model()
self.sd.unet.eval()
self.sd.unet.to(self.device_torch)
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
te.to(self.device_torch)
else:
self.sd.text_encoder.eval()
self.sd.to(self.device_torch)
print(f"Converting to LoRM UNet")
# replace the unet with LoRMUnet
convert_diffusers_unet_to_lorm(
self.sd.unet,
config=self.lorm_config,
)
sample_folder = os.path.join(self.output_folder)
gen_img_config_list = []
sample_config = self.generate_config
start_seed = sample_config.seed
current_seed = start_seed
for i in range(len(sample_config.prompts)):
if sample_config.walk_seed:
current_seed = start_seed + i
filename = f"[time]_[count].{self.generate_config.ext}"
output_path = os.path.join(sample_folder, filename)
prompt = sample_config.prompts[i]
extra_args = {}
gen_img_config_list.append(GenerateImageConfig(
prompt=prompt, # it will autoparse the prompt
width=sample_config.width,
height=sample_config.height,
negative_prompt=sample_config.neg,
seed=current_seed,
guidance_scale=sample_config.guidance_scale,
guidance_rescale=sample_config.guidance_rescale,
num_inference_steps=sample_config.sample_steps,
network_multiplier=sample_config.network_multiplier,
output_path=output_path,
output_ext=sample_config.ext,
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
**extra_args
))
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
print("Done generating images")
# cleanup
del self.sd
gc.collect()
torch.cuda.empty_cache()

View File

@@ -0,0 +1,193 @@
import os
import random
from collections import OrderedDict
from typing import List
import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLAdapterPipeline
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.train_tools import get_torch_dtype
from controlnet_aux.midas import MidasDetector
from diffusers.utils import load_image
def flush():
torch.cuda.empty_cache()
gc.collect()
class GenerateConfig:
def __init__(self, **kwargs):
self.prompts: List[str]
self.sampler = kwargs.get('sampler', 'ddpm')
self.neg = kwargs.get('neg', '')
self.seed = kwargs.get('seed', -1)
self.walk_seed = kwargs.get('walk_seed', False)
self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
self.prompt_2 = kwargs.get('prompt_2', None)
self.neg_2 = kwargs.get('neg_2', None)
self.prompts = kwargs.get('prompts', None)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext = kwargs.get('ext', 'png')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
if kwargs.get('shuffle', False):
# shuffle the prompts
random.shuffle(self.prompts)
class ReferenceGenerator(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.output_folder = self.get_conf('output_folder', required=True)
self.device = self.get_conf('device', 'cuda')
self.model_config = ModelConfig(**self.get_conf('model', required=True))
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
self.is_latents_cached = True
raw_datasets = self.get_conf('datasets', None)
if raw_datasets is not None and len(raw_datasets) > 0:
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
self.datasets = None
self.datasets_reg = None
self.dtype = self.get_conf('dtype', 'float16')
self.torch_dtype = get_torch_dtype(self.dtype)
self.params = []
if raw_datasets is not None and len(raw_datasets) > 0:
for raw_dataset in raw_datasets:
dataset = DatasetConfig(**raw_dataset)
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
if not is_caching:
self.is_latents_cached = False
if dataset.is_reg:
if self.datasets_reg is None:
self.datasets_reg = []
self.datasets_reg.append(dataset)
else:
if self.datasets is None:
self.datasets = []
self.datasets.append(dataset)
self.progress_bar = None
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
dtype=self.dtype,
)
print(f"Using device {self.device}")
self.data_loader: DataLoader = None
self.adapter: T2IAdapter = None
def run(self):
super().run()
print("Loading model...")
self.sd.load_model()
device = torch.device(self.device)
if self.generate_config.t2i_adapter_path is not None:
self.adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
).to(device)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to(device)
pipe = StableDiffusionXLAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=get_sampler(self.generate_config.sampler),
adapter=self.adapter,
).to(device)
pipe.set_progress_bar_config(disable=True)
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
num_batches = len(self.data_loader)
pbar = tqdm(total=num_batches, desc="Generating images")
seed = self.generate_config.seed
# load images from datasets, use tqdm
for i, batch in enumerate(self.data_loader):
batch: DataLoaderBatchDTO = batch
file_item: FileItemDTO = batch.file_items[0]
img_path = file_item.path
img_filename = os.path.basename(img_path)
img_filename_no_ext = os.path.splitext(img_filename)[0]
output_path = os.path.join(self.output_folder, img_filename)
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
caption = batch.get_caption_list()[0]
img: torch.Tensor = batch.tensor.clone()
# image comes in -1 to 1. convert to a PIL RGB image
img = (img + 1) / 2
img = img.clamp(0, 1)
img = img[0].permute(1, 2, 0).cpu().numpy()
img = (img * 255).astype(np.uint8)
image = Image.fromarray(img)
width, height = image.size
min_res = min(width, height)
if self.generate_config.walk_seed:
seed = seed + 1
if self.generate_config.seed == -1:
# random
seed = random.randint(0, 1000000)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# generate depth map
image = midas_depth(
image,
detect_resolution=min_res, # do 512 ?
image_resolution=min_res
)
# image.save(output_depth_path)
gen_images = pipe(
prompt=caption,
negative_prompt=self.generate_config.neg,
image=image,
num_inference_steps=self.generate_config.sample_steps,
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
guidance_scale=self.generate_config.guidance_scale,
).images[0]
gen_images.save(output_path)
# save caption
with open(output_caption_path, 'w') as f:
f.write(caption)
pbar.update(1)
batch.cleanup()
pbar.close()
print("Done generating images")
# cleanup
del self.sd
gc.collect()
torch.cuda.empty_cache()

View File

@@ -0,0 +1,42 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class AdvancedReferenceGeneratorExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "reference_generator"
# name is the name of the extension for printing
name = "Reference Generator"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .ReferenceGenerator import ReferenceGenerator
return ReferenceGenerator
# This is for generic training (LoRA, Dreambooth, FineTuning)
class PureLoraGenerator(Extension):
# uid must be unique, it is how the extension is identified
uid = "pure_lora_generator"
# name is the name of the extension for printing
name = "Pure LoRA Generator"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .PureLoraGenerator import PureLoraGenerator
return PureLoraGenerator
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
AdvancedReferenceGeneratorExtension, PureLoraGenerator
]

View File

@@ -0,0 +1,91 @@
---
job: extension
config:
name: test_v1
process:
- type: 'textual_inversion_trainer'
training_folder: "out/TI"
device: cuda:0
# for tensorboard logging
log_dir: "out/.tensorboard"
embedding:
trigger: "your_trigger_here"
tokens: 12
init_words: "man with short brown hair"
save_format: "safetensors" # 'safetensors' or 'pt'
save:
dtype: float16 # precision to save
save_every: 100 # save every this many steps
max_step_saves_to_keep: 5 # only affects step counts
datasets:
- folder_path: "/path/to/dataset"
caption_ext: "txt"
default_caption: "[trigger]"
buckets: true
resolution: 512
train:
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
steps: 3000
weight_jitter: 0.0
lr: 5e-5
train_unet: false
gradient_checkpointing: true
train_text_encoder: false
optimizer: "adamw"
# optimizer: "prodigy"
optimizer_params:
weight_decay: 1e-2
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 4
dtype: bf16
xformers: true
min_snr_gamma: 5.0
# skip_first_sample: true
noise_offset: 0.0 # not needed for this
model:
# objective reality v2
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 100 # sample every this many steps
width: 512
height: 512
prompts:
- "photo of [trigger] laughing"
- "photo of [trigger] smiling"
- "[trigger] close up"
- "dark scene [trigger] frozen"
- "[trigger] nighttime"
- "a painting of [trigger]"
- "a drawing of [trigger]"
- "a cartoon of [trigger]"
- "[trigger] pixar style"
- "[trigger] costume"
neg: ""
seed: 42
walk_seed: false
guidance_scale: 7
sample_steps: 20
network_multiplier: 1.0
logging:
log_every: 10 # log every this many steps
use_wandb: false # not supported yet
verbose: false
# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
# [name] gets replaced with the name above
name: "[name]"
# version: '1.0'
# creator:
# name: Your Name
# email: your@gmail.com
# website: https://your.website

View File

@@ -0,0 +1,151 @@
import random
from collections import OrderedDict
from torch.utils.data import DataLoader
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
import torch
from jobs.process import BaseSDTrainProcess
def flush():
torch.cuda.empty_cache()
gc.collect()
class ConceptReplacementConfig:
def __init__(self, **kwargs):
self.concept: str = kwargs.get('concept', '')
self.replacement: str = kwargs.get('replacement', '')
class ConceptReplacer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
replacement_list = self.config.get('replacements', [])
self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
def before_model_load(self):
pass
def hook_before_train_loop(self):
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
# textual inversion
if self.embedding is not None:
# set text encoder to train. Not sure if this is necessary but diffusers example did it
self.sd.text_encoder.train()
def hook_train_loop(self, batch):
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
batch_replacement_list = []
# get a random replacement for each prompt
for prompt in conditioned_prompts:
replacement = random.choice(self.replacement_list)
batch_replacement_list.append(replacement)
# build out prompts
concept_prompts = []
replacement_prompts = []
for idx, replacement in enumerate(batch_replacement_list):
prompt = conditioned_prompts[idx]
# insert shuffled concept at beginning and end of prompt
shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
random.shuffle(shuffled_concept)
shuffled_concept = ', '.join(shuffled_concept)
concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
# insert replacement at beginning and end of prompt
shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
random.shuffle(shuffled_replacement)
shuffled_replacement = ', '.join(shuffled_replacement)
replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
# predict the replacement without network
conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
replacement_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
)
del conditional_embeds
replacement_pred = replacement_pred.detach()
self.optimizer.zero_grad()
flush()
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding:
grad_on_text_encoder = True
# set the weights
network.multiplier = network_weight_list
# activate network if it exits
with network:
with torch.set_grad_enabled(grad_on_text_encoder):
# embed the prompts
conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
self.optimizer.zero_grad()
flush()
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
)
# reset network multiplier
network.multiplier = 1.0
return loss_dict

View File

@@ -0,0 +1,26 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class ConceptReplacerExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "concept_replacer"
# name is the name of the extension for printing
name = "Concept Replacer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .ConceptReplacer import ConceptReplacer
return ConceptReplacer
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
ConceptReplacerExtension,
]

View File

@@ -0,0 +1,91 @@
---
job: extension
config:
name: test_v1
process:
- type: 'textual_inversion_trainer'
training_folder: "out/TI"
device: cuda:0
# for tensorboard logging
log_dir: "out/.tensorboard"
embedding:
trigger: "your_trigger_here"
tokens: 12
init_words: "man with short brown hair"
save_format: "safetensors" # 'safetensors' or 'pt'
save:
dtype: float16 # precision to save
save_every: 100 # save every this many steps
max_step_saves_to_keep: 5 # only affects step counts
datasets:
- folder_path: "/path/to/dataset"
caption_ext: "txt"
default_caption: "[trigger]"
buckets: true
resolution: 512
train:
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
steps: 3000
weight_jitter: 0.0
lr: 5e-5
train_unet: false
gradient_checkpointing: true
train_text_encoder: false
optimizer: "adamw"
# optimizer: "prodigy"
optimizer_params:
weight_decay: 1e-2
lr_scheduler: "constant"
max_denoising_steps: 1000
batch_size: 4
dtype: bf16
xformers: true
min_snr_gamma: 5.0
# skip_first_sample: true
noise_offset: 0.0 # not needed for this
model:
# objective reality v2
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
is_v2: false # for v2 models
is_xl: false # for SDXL models
is_v_pred: false # for v-prediction models (most v2 models)
sample:
sampler: "ddpm" # must match train.noise_scheduler
sample_every: 100 # sample every this many steps
width: 512
height: 512
prompts:
- "photo of [trigger] laughing"
- "photo of [trigger] smiling"
- "[trigger] close up"
- "dark scene [trigger] frozen"
- "[trigger] nighttime"
- "a painting of [trigger]"
- "a drawing of [trigger]"
- "a cartoon of [trigger]"
- "[trigger] pixar style"
- "[trigger] costume"
neg: ""
seed: 42
walk_seed: false
guidance_scale: 7
sample_steps: 20
network_multiplier: 1.0
logging:
log_every: 10 # log every this many steps
use_wandb: false # not supported yet
verbose: false
# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
# [name] gets replaced with the name above
name: "[name]"
# version: '1.0'
# creator:
# name: Your Name
# email: your@gmail.com
# website: https://your.website

View File

@@ -0,0 +1,20 @@
from collections import OrderedDict
import gc
import torch
from jobs.process import BaseExtensionProcess
def flush():
torch.cuda.empty_cache()
gc.collect()
class DatasetTools(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
def run(self):
super().run()
raise NotImplementedError("This extension is not yet implemented")

View File

@@ -0,0 +1,196 @@
import copy
import json
import os
from collections import OrderedDict
import gc
import traceback
import torch
from PIL import Image, ImageOps
from tqdm import tqdm
from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
from .tools.fuyu_utils import FuyuImageProcessor
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
from .tools.llava_utils import LLaVAImageProcessor
from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
from jobs.process import BaseExtensionProcess
from .tools.sync_tools import get_img_paths
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
def flush():
torch.cuda.empty_cache()
gc.collect()
VERSION = 2
class SuperTagger(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
parent_dir = config.get('parent_dir', None)
self.dataset_paths: list[str] = config.get('dataset_paths', [])
self.device = config.get('device', 'cuda')
self.steps: list[Step] = config.get('steps', [])
self.caption_method = config.get('caption_method', 'llava:default')
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
self.force_reprocess_img = config.get('force_reprocess_img', False)
self.caption_replacements = config.get('caption_replacements', default_replacements)
self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
self.master_dataset_dict = OrderedDict()
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
if parent_dir is not None and len(self.dataset_paths) == 0:
# find all folders in the patent_dataset_path
self.dataset_paths = [
os.path.join(parent_dir, folder)
for folder in os.listdir(parent_dir)
if os.path.isdir(os.path.join(parent_dir, folder))
]
else:
# make sure they exist
for dataset_path in self.dataset_paths:
if not os.path.exists(dataset_path):
raise ValueError(f"Dataset path does not exist: {dataset_path}")
print(f"Found {len(self.dataset_paths)} dataset paths")
self.image_processor: ImageProcessor = self.get_image_processor()
def get_image_processor(self):
if self.caption_method.startswith('llava'):
return LLaVAImageProcessor(device=self.device)
elif self.caption_method.startswith('fuyu'):
return FuyuImageProcessor(device=self.device)
else:
raise ValueError(f"Unknown caption method: {self.caption_method}")
def process_image(self, img_path: str):
root_img_dir = os.path.dirname(os.path.dirname(img_path))
filename = os.path.basename(img_path)
filename_no_ext = os.path.splitext(filename)[0]
train_dir = os.path.join(root_img_dir, TRAIN_DIR)
train_img_path = os.path.join(train_dir, filename)
json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
# check if json exists, if it does load it as image info
if os.path.exists(json_path):
with open(json_path, 'r') as f:
img_info = ImgInfo(**json.load(f))
else:
img_info = ImgInfo()
# always send steps first in case other processes need them
img_info.add_steps(copy.deepcopy(self.steps))
img_info.set_version(VERSION)
img_info.set_caption_method(self.caption_method)
image: Image = None
caption_image: Image = None
did_update_image = False
# trigger reprocess of steps
if self.force_reprocess_img:
img_info.trigger_image_reprocess()
# set the image as updated if it does not exist on disk
if not os.path.exists(train_img_path):
did_update_image = True
image = load_image(img_path)
if img_info.force_image_process:
did_update_image = True
image = load_image(img_path)
# go through the needed steps
for step in copy.deepcopy(img_info.state.steps_to_complete):
if step == 'caption':
# load image
if image is None:
image = load_image(img_path)
if caption_image is None:
caption_image = resize_to_max(image, 1024, 1024)
if not self.image_processor.is_loaded:
print('Loading Model. Takes a while, especially the first time')
self.image_processor.load_model()
img_info.caption = self.image_processor.generate_caption(
image=caption_image,
prompt=self.caption_prompt,
replacements=self.caption_replacements
)
img_info.mark_step_complete(step)
elif step == 'caption_short':
# load image
if image is None:
image = load_image(img_path)
if caption_image is None:
caption_image = resize_to_max(image, 1024, 1024)
if not self.image_processor.is_loaded:
print('Loading Model. Takes a while, especially the first time')
self.image_processor.load_model()
img_info.caption_short = self.image_processor.generate_caption(
image=caption_image,
prompt=self.caption_short_prompt,
replacements=self.caption_short_replacements
)
img_info.mark_step_complete(step)
elif step == 'contrast_stretch':
# load image
if image is None:
image = load_image(img_path)
image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
did_update_image = True
img_info.mark_step_complete(step)
else:
raise ValueError(f"Unknown step: {step}")
os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
if did_update_image:
image.save(train_img_path)
if img_info.is_dirty:
with open(json_path, 'w') as f:
json.dump(img_info.to_dict(), f, indent=4)
if self.dataset_master_config_file:
# add to master dict
self.master_dataset_dict[train_img_path] = img_info.to_dict()
def run(self):
super().run()
imgs_to_process = []
# find all images
for dataset_path in self.dataset_paths:
raw_dir = os.path.join(dataset_path, RAW_DIR)
raw_image_paths = get_img_paths(raw_dir)
for raw_image_path in raw_image_paths:
imgs_to_process.append(raw_image_path)
if len(imgs_to_process) == 0:
print(f"No images to process")
else:
print(f"Found {len(imgs_to_process)} to process")
for img_path in tqdm(imgs_to_process, desc="Processing images"):
try:
self.process_image(img_path)
except Exception:
# print full stack trace
print(traceback.format_exc())
continue
# self.process_image(img_path)
if self.dataset_master_config_file is not None:
# save it as json
with open(self.dataset_master_config_file, 'w') as f:
json.dump(self.master_dataset_dict, f, indent=4)
del self.image_processor
flush()

View File

@@ -0,0 +1,131 @@
import os
import shutil
from collections import OrderedDict
import gc
from typing import List
import torch
from tqdm import tqdm
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
get_img_paths
from jobs.process import BaseExtensionProcess
def flush():
torch.cuda.empty_cache()
gc.collect()
class SyncFromCollection(BaseExtensionProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.min_width = config.get('min_width', 1024)
self.min_height = config.get('min_height', 1024)
# add our min_width and min_height to each dataset config if they don't exist
for dataset_config in config.get('dataset_sync', []):
if 'min_width' not in dataset_config:
dataset_config['min_width'] = self.min_width
if 'min_height' not in dataset_config:
dataset_config['min_height'] = self.min_height
self.dataset_configs: List[DatasetSyncCollectionConfig] = [
DatasetSyncCollectionConfig(**dataset_config)
for dataset_config in config.get('dataset_sync', [])
]
print(f"Found {len(self.dataset_configs)} dataset configs")
def move_new_images(self, root_dir: str):
raw_dir = os.path.join(root_dir, RAW_DIR)
new_dir = os.path.join(root_dir, NEW_DIR)
new_images = get_img_paths(new_dir)
for img_path in new_images:
# move to raw
new_path = os.path.join(raw_dir, os.path.basename(img_path))
shutil.move(img_path, new_path)
# remove new dir
shutil.rmtree(new_dir)
def sync_dataset(self, config: DatasetSyncCollectionConfig):
if config.host == 'unsplash':
get_images = get_unsplash_images
elif config.host == 'pexels':
get_images = get_pexels_images
else:
raise ValueError(f"Unknown host: {config.host}")
results = {
'num_downloaded': 0,
'num_skipped': 0,
'bad': 0,
'total': 0,
}
photos = get_images(config)
raw_dir = os.path.join(config.directory, RAW_DIR)
new_dir = os.path.join(config.directory, NEW_DIR)
raw_images = get_local_image_file_names(raw_dir)
new_images = get_local_image_file_names(new_dir)
for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
try:
if photo.filename not in raw_images and photo.filename not in new_images:
download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
results['num_downloaded'] += 1
else:
results['num_skipped'] += 1
except Exception as e:
print(f" - BAD({photo.id}): {e}")
results['bad'] += 1
continue
results['total'] += 1
return results
def print_results(self, results):
print(
f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
def run(self):
super().run()
print(f"Syncing {len(self.dataset_configs)} datasets")
all_results = None
failed_datasets = []
for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
try:
results = self.sync_dataset(dataset_config)
if all_results is None:
all_results = {**results}
else:
for key, value in results.items():
all_results[key] += value
self.print_results(results)
except Exception as e:
print(f" - FAILED: {e}")
if 'response' in e.__dict__:
error = f"{e.response.status_code}: {e.response.text}"
print(f" - {error}")
failed_datasets.append({'dataset': dataset_config, 'error': error})
else:
failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
continue
print("Moving new images to raw")
for dataset_config in self.dataset_configs:
self.move_new_images(dataset_config.directory)
print("Done syncing datasets")
self.print_results(all_results)
if len(failed_datasets) > 0:
print(f"Failed to sync {len(failed_datasets)} datasets")
for failed in failed_datasets:
print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
print(f" - ERR: {failed['error']}")

View File

@@ -0,0 +1,43 @@
from toolkit.extension import Extension
class DatasetToolsExtension(Extension):
uid = "dataset_tools"
# name is the name of the extension for printing
name = "Dataset Tools"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .DatasetTools import DatasetTools
return DatasetTools
class SyncFromCollectionExtension(Extension):
uid = "sync_from_collection"
name = "Sync from Collection"
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .SyncFromCollection import SyncFromCollection
return SyncFromCollection
class SuperTaggerExtension(Extension):
uid = "super_tagger"
name = "Super Tagger"
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .SuperTagger import SuperTagger
return SuperTagger
AI_TOOLKIT_EXTENSIONS = [
SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
]

View File

@@ -0,0 +1,53 @@
caption_manipulation_steps = ['caption', 'caption_short']
default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
default_short_prompt = 'caption this image in less than ten words'
default_replacements = [
("the image features", ""),
("the image shows", ""),
("the image depicts", ""),
("the image is", ""),
("in this image", ""),
("in the image", ""),
]
def clean_caption(cap, replacements=None):
if replacements is None:
replacements = default_replacements
# remove any newlines
cap = cap.replace("\n", ", ")
cap = cap.replace("\r", ", ")
cap = cap.replace(".", ",")
cap = cap.replace("\"", "")
# remove unicode characters
cap = cap.encode('ascii', 'ignore').decode('ascii')
# make lowercase
cap = cap.lower()
# remove any extra spaces
cap = " ".join(cap.split())
for replacement in replacements:
if replacement[0].startswith('*'):
# we are removing all text if it starts with this and the rest matches
search_text = replacement[0][1:]
if cap.startswith(search_text):
cap = ""
else:
cap = cap.replace(replacement[0].lower(), replacement[1].lower())
cap_list = cap.split(",")
# trim whitespace
cap_list = [c.strip() for c in cap_list]
# remove empty strings
cap_list = [c for c in cap_list if c != ""]
# remove duplicates
cap_list = list(dict.fromkeys(cap_list))
# join back together
cap = ", ".join(cap_list)
return cap

View File

@@ -0,0 +1,187 @@
import json
from typing import Literal, Type, TYPE_CHECKING
Host: Type = Literal['unsplash', 'pexels']
RAW_DIR = "raw"
NEW_DIR = "_tmp"
TRAIN_DIR = "train"
DEPTH_DIR = "depth"
from .image_tools import Step, img_manipulation_steps
from .caption import caption_manipulation_steps
class DatasetSyncCollectionConfig:
def __init__(self, **kwargs):
self.host: Host = kwargs.get('host', None)
self.collection_id: str = kwargs.get('collection_id', None)
self.directory: str = kwargs.get('directory', None)
self.api_key: str = kwargs.get('api_key', None)
self.min_width: int = kwargs.get('min_width', 1024)
self.min_height: int = kwargs.get('min_height', 1024)
if self.host is None:
raise ValueError("host is required")
if self.collection_id is None:
raise ValueError("collection_id is required")
if self.directory is None:
raise ValueError("directory is required")
if self.api_key is None:
raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
class ImageState:
def __init__(self, **kwargs):
self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])
def to_dict(self):
return {
'steps_complete': self.steps_complete
}
class Rect:
def __init__(self, **kwargs):
self.x = kwargs.get('x', 0)
self.y = kwargs.get('y', 0)
self.width = kwargs.get('width', 0)
self.height = kwargs.get('height', 0)
def to_dict(self):
return {
'x': self.x,
'y': self.y,
'width': self.width,
'height': self.height
}
class ImgInfo:
def __init__(self, **kwargs):
self.version: int = kwargs.get('version', None)
self.caption: str = kwargs.get('caption', None)
self.caption_short: str = kwargs.get('caption_short', None)
self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
self.state = ImageState(**kwargs.get('state', {}))
self.caption_method = kwargs.get('caption_method', None)
self.other_captions = kwargs.get('other_captions', {})
self._upgrade_state()
self.force_image_process: bool = False
self._requested_steps: list[Step] = []
self.is_dirty: bool = False
def _upgrade_state(self):
# upgrades older states
if self.caption is not None and 'caption' not in self.state.steps_complete:
self.mark_step_complete('caption')
self.is_dirty = True
if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
self.mark_step_complete('caption_short')
self.is_dirty = True
if self.caption_method is None and self.caption is not None:
# added caption method in version 2. Was all llava before that
self.caption_method = 'llava:default'
self.is_dirty = True
def to_dict(self):
return {
'version': self.version,
'caption_method': self.caption_method,
'caption': self.caption,
'caption_short': self.caption_short,
'poi': [poi.to_dict() for poi in self.poi],
'state': self.state.to_dict(),
'other_captions': self.other_captions
}
def mark_step_complete(self, step: Step):
if step not in self.state.steps_complete:
self.state.steps_complete.append(step)
if step in self.state.steps_to_complete:
self.state.steps_to_complete.remove(step)
self.is_dirty = True
def add_step(self, step: Step):
if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
self.state.steps_to_complete.append(step)
def trigger_image_reprocess(self):
if self._requested_steps is None:
raise Exception("Must call add_steps before trigger_image_reprocess")
steps = self._requested_steps
# remove all image manipulationf from steps_to_complete
for step in img_manipulation_steps:
if step in self.state.steps_to_complete:
self.state.steps_to_complete.remove(step)
if step in self.state.steps_complete:
self.state.steps_complete.remove(step)
self.force_image_process = True
self.is_dirty = True
# we want to keep the order passed in process file
for step in steps:
if step in img_manipulation_steps:
self.add_step(step)
def add_steps(self, steps: list[Step]):
self._requested_steps = [step for step in steps]
for stage in steps:
self.add_step(stage)
# update steps if we have any img processes not complete, we have to reprocess them all
# if any steps_to_complete are in img_manipulation_steps
is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
order_has_changed = False
if not is_manipulating_image:
# check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
current_img_manipulation_order = [step for step in self.state.steps_complete if
step in img_manipulation_steps]
if target_img_manipulation_order != current_img_manipulation_order:
order_has_changed = True
if is_manipulating_image or order_has_changed:
self.trigger_image_reprocess()
def set_caption_method(self, method: str):
if self._requested_steps is None:
raise Exception("Must call add_steps before set_caption_method")
if self.caption_method != method:
self.is_dirty = True
# move previous caption method to other_captions
if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
self.other_captions[self.caption_method] = {
'caption': self.caption,
'caption_short': self.caption_short,
}
self.caption_method = method
self.caption = None
self.caption_short = None
# see if we have a caption from the new method
if method in self.other_captions:
self.caption = self.other_captions[method].get('caption', None)
self.caption_short = self.other_captions[method].get('caption_short', None)
else:
self.trigger_new_caption()
def trigger_new_caption(self):
self.caption = None
self.caption_short = None
self.is_dirty = True
# check to see if we have any steps in the complete list and move them to the to_complete list
for step in self.state.steps_complete:
if step in caption_manipulation_steps:
self.state.steps_complete.remove(step)
self.state.steps_to_complete.append(step)
def to_json(self):
return json.dumps(self.to_dict())
def set_version(self, version: int):
if self.version != version:
self.is_dirty = True
self.version = version

View File

@@ -0,0 +1,66 @@
from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
import torch
from PIL import Image
class FuyuImageProcessor:
def __init__(self, device='cuda'):
from transformers import FuyuProcessor, FuyuForCausalLM
self.device = device
self.model: FuyuForCausalLM = None
self.processor: FuyuProcessor = None
self.dtype = torch.bfloat16
self.tokenizer: AutoTokenizer
self.is_loaded = False
def load_model(self):
from transformers import FuyuProcessor, FuyuForCausalLM
model_path = "adept/fuyu-8b"
kwargs = {"device_map": self.device}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=self.dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
self.processor = FuyuProcessor.from_pretrained(model_path)
self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
self.is_loaded = True
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer)
def generate_caption(
self, image: Image,
prompt: str = default_long_prompt,
replacements=default_replacements,
max_new_tokens=512
):
# prepare inputs for the model
# text_prompt = f"{prompt}\n"
# image = image.convert('RGB')
model_inputs = self.processor(text=prompt, images=[image])
model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
model_inputs.items()}
generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
prompt_len = model_inputs["input_ids"].shape[-1]
output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
output = clean_caption(output, replacements=replacements)
return output
# inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
# for k, v in inputs.items():
# inputs[k] = v.to(self.device)
# # autoregressively generate text
# generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
# output = generation_text[0]
#
# return clean_caption(output, replacements=replacements)

View File

@@ -0,0 +1,49 @@
from typing import Literal, Type, TYPE_CHECKING, Union
import cv2
import numpy as np
from PIL import Image, ImageOps
Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']
img_manipulation_steps = ['contrast_stretch']
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
if TYPE_CHECKING:
from .llava_utils import LLaVAImageProcessor
from .fuyu_utils import FuyuImageProcessor
ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']
def pil_to_cv2(image):
"""Convert a PIL image to a cv2 image."""
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
def cv2_to_pil(image):
"""Convert a cv2 image to a PIL image."""
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
def load_image(img_path: str):
image = Image.open(img_path).convert('RGB')
try:
# transpose with exif data
image = ImageOps.exif_transpose(image)
except Exception as e:
pass
return image
def resize_to_max(image, max_width=1024, max_height=1024):
width, height = image.size
if width <= max_width and height <= max_height:
return image
scale = min(max_width / width, max_height / height)
width = int(width * scale)
height = int(height * scale)
return image.resize((width, height), Image.LANCZOS)

View File

@@ -0,0 +1,85 @@
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
import torch
from PIL import Image, ImageOps
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
class LLaVAImageProcessor:
def __init__(self, device='cuda'):
try:
from llava.model import LlavaLlamaForCausalLM
except ImportError:
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
print(
"You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
raise
self.device = device
self.model: LlavaLlamaForCausalLM = None
self.tokenizer: AutoTokenizer = None
self.image_processor: CLIPImageProcessor = None
self.is_loaded = False
def load_model(self):
from llava.model import LlavaLlamaForCausalLM
model_path = "4bit/llava-v1.5-13b-3GB"
# kwargs = {"device_map": "auto"}
kwargs = {"device_map": self.device}
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
vision_tower = self.model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device=self.device)
self.image_processor = vision_tower.image_processor
self.is_loaded = True
def generate_caption(
self, image:
Image, prompt: str = default_long_prompt,
replacements=default_replacements,
max_new_tokens=512
):
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
# question = "how many dogs are in the picture?"
disable_torch_init()
conv_mode = "llava_v0"
conv = conv_templates[conv_mode].copy()
roles = conv.roles
image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda()
inp = f"{roles[0]}: {prompt}"
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
raw_prompt = conv.get_prompt()
input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
return_tensors='pt').unsqueeze(0).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
top_p=0.8
)
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
conv.messages[-1][-1] = outputs
output = outputs.rsplit('</s>', 1)[0]
return clean_caption(output, replacements=replacements)

View File

@@ -0,0 +1,279 @@
import os
import requests
import tqdm
from typing import List, Optional, TYPE_CHECKING
def img_root_path(img_id: str):
return os.path.dirname(os.path.dirname(img_id))
if TYPE_CHECKING:
from .dataset_tools_config_modules import DatasetSyncCollectionConfig
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
class Photo:
def __init__(
self,
id,
host,
width,
height,
url,
filename
):
self.id = str(id)
self.host = host
self.width = width
self.height = height
self.url = url
self.filename = filename
def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
if img_width > img_height:
scale = min_height / img_height
else:
scale = min_width / img_width
new_width = int(img_width * scale)
new_height = int(img_height * scale)
return new_width, new_height
def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
all_images = []
next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"
while True:
response = requests.get(next_page, headers={
"Authorization": f"{config.api_key}"
})
response.raise_for_status()
data = response.json()
all_images.extend(data['media'])
if 'next_page' in data and data['next_page']:
next_page = data['next_page']
else:
break
photos = []
for image in all_images:
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
filename = os.path.basename(image['src']['original'])
photos.append(Photo(
id=image['id'],
host="pexels",
width=image['width'],
height=image['height'],
url=url,
filename=filename
))
return photos
def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
headers = {
# "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
"Authorization": f"Client-ID {config.api_key}"
}
# headers['Authorization'] = f"Bearer {token}"
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
response = requests.get(url, headers=headers)
response.raise_for_status()
res_headers = response.headers
# parse the link header to get the next page
# 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
has_next_page = False
if 'Link' in res_headers:
has_next_page = True
link_header = res_headers['Link']
link_header = link_header.split(',')
link_header = [link.strip() for link in link_header]
link_header = [link.split(';') for link in link_header]
link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
link_header = {link[1]: link[0] for link in link_header}
# get page number from last url
last_page = link_header['rel="last']
last_page = last_page.split('?')[1]
last_page = last_page.split('&')
last_page = [param.split('=') for param in last_page]
last_page = {param[0]: param[1] for param in last_page}
last_page = int(last_page['page'])
all_images = response.json()
if has_next_page:
# assume we start on page 1, so we don't need to get it again
for page in tqdm.tqdm(range(2, last_page + 1)):
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
response = requests.get(url, headers=headers)
response.raise_for_status()
all_images.extend(response.json())
photos = []
for image in all_images:
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
url = f"{image['urls']['raw']}&w={new_width}"
filename = f"{image['id']}.jpg"
photos.append(Photo(
id=image['id'],
host="unsplash",
width=image['width'],
height=image['height'],
url=url,
filename=filename
))
return photos
def get_img_paths(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = os.listdir(dir_path)
# remove non image files
local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts]
# make full path
local_files = [os.path.join(dir_path, file) for file in local_files]
return local_files
def get_local_image_ids(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = get_img_paths(dir_path)
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
return set([os.path.basename(file).split('.')[0] for file in local_files])
def get_local_image_file_names(dir_path: str):
os.makedirs(dir_path, exist_ok=True)
local_files = get_img_paths(dir_path)
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
return set([os.path.basename(file) for file in local_files])
def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
img_width = photo.width
img_height = photo.height
if img_width < min_width or img_height < min_height:
raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")
img_response = requests.get(photo.url)
img_response.raise_for_status()
os.makedirs(dir_path, exist_ok=True)
filename = os.path.join(dir_path, photo.filename)
with open(filename, 'wb') as file:
file.write(img_response.content)
def update_caption(img_path: str):
# if the caption is a txt file, convert it to a json file
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
# see if it exists
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")):
# todo add poi and what not
return # we have a json file
caption = ""
# see if txt file exists
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")):
# read it
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file:
caption = file.read()
# write json file
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file:
file.write(f'{{"caption": "{caption}"}}')
# delete txt file
os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"))
# def equalize_img(img_path: str):
# input_path = img_path
# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
# process_img(
# img_path=input_path,
# output_path=output_path,
# equalize=True,
# max_size=2056,
# white_balance=False,
# gamma_correction=False,
# strength=0.6,
# )
# def annotate_depth(img_path: str):
# # make fake args
# args = argparse.Namespace()
# args.annotator = "midas"
# args.res = 1024
#
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#
# output = annotate(img, args)
#
# output = output.astype('uint8')
# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
#
# cv2.imwrite(output_path, output)
# def invert_depth(img_path: str):
# img = cv2.imread(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# # invert the colors
# img = cv2.bitwise_not(img)
#
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
# output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
# cv2.imwrite(output_path, img)
#
# # update our list of raw images
# raw_images = get_img_paths(raw_dir)
#
# # update raw captions
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
# update_caption(image_id)
#
# # equalize images
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
# if img_path not in eq_images:
# equalize_img(img_path)
#
# # update our list of eq images
# eq_images = get_img_paths(eq_dir)
# # update eq captions
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
# update_caption(image_id)
#
# # annotate depth
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
# depth_images = get_img_paths(depth_dir)
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
# if img_path not in depth_images:
# annotate_depth(img_path)
#
# depth_images = get_img_paths(depth_dir)
#
# # invert depth
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
# inv_depth_images = get_img_paths(inv_depth_dir)
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
# if img_path not in inv_depth_images:
# invert_depth(img_path)

View File

@@ -16,6 +16,7 @@ from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
from toolkit.basic import value_map
def flush():
@@ -91,11 +92,18 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
network_neg_weight = network_neg_weight.item()
# get an array of random floats between -weight_jitter and weight_jitter
loss_jitter_multiplier = 1.0
weight_jitter = self.slider_config.weight_jitter
if weight_jitter > 0.0:
jitter_list = random.uniform(-weight_jitter, weight_jitter)
orig_network_pos_weight = network_pos_weight
network_pos_weight += jitter_list
network_neg_weight += (jitter_list * -1.0)
# penalize the loss for its distance from network_pos_weight
# a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter
# so the loss_jitter_multiplier needs to be 0.4
loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)
# if items in network_weight list are tensors, convert them to floats
@@ -146,38 +154,33 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
timesteps = torch.cat([timesteps, timesteps], dim=0)
network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
flush()
loss_float = None
loss_mirror_float = None
self.optimizer.zero_grad()
noisy_latents.requires_grad = False
# if training text encoder enable grads, else do context of no grad
with torch.set_grad_enabled(self.train_config.train_text_encoder):
# text encoding
embedding_list = []
# embed the prompts
# fix issue with them being tuples sometimes
prompt_list = []
for prompt in prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
if isinstance(prompt, tuple):
prompt = prompt[0]
prompt_list.append(prompt)
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
if self.model_config.is_xl:
# todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
network_multiplier_list = network_multiplier
noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
noise_list = torch.chunk(noise, 2, dim=0)
timesteps_list = torch.chunk(timesteps, 2, dim=0)
conditional_embeds_list = split_prompt_embeds(conditional_embeds)
else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
# if self.model_config.is_xl:
# # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
# network_multiplier_list = network_multiplier
# noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
# noise_list = torch.chunk(noise, 2, dim=0)
# timesteps_list = torch.chunk(timesteps, 2, dim=0)
# conditional_embeds_list = split_prompt_embeds(conditional_embeds)
# else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
losses = []
# allow to chunk it out to save vram
@@ -205,20 +208,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# todo add snr gamma here
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss_slide_float = loss.item()
loss = loss.mean() * loss_jitter_multiplier
loss_float = loss.item()
losses.append(loss_float)
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
optimizer.step()

View File

@@ -1,11 +1,21 @@
from collections import OrderedDict
from torch.utils.data import DataLoader
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from typing import Union, Literal, List
from diffusers import T2IAdapter
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
apply_learnable_snr_gos, LearnableSNRGamma
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
def flush():
@@ -13,101 +23,618 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.do_long_prompts = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
def before_model_load(self):
pass
def before_dataset_load(self):
self.assistant_adapter = None
# get adapter assistant if one is set
if self.train_config.adapter_assist_name_or_path is not None:
adapter_path = self.train_config.adapter_assist_name_or_path
# dont name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
).to(self.device_torch)
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
def hook_before_train_loop(self):
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
# textual inversion
if self.embedding is not None:
# keep original embeddings as reference
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
# set text encoder to train. Not sure if this is necessary but diffusers example did it
self.sd.text_encoder.train()
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
self.optimizer.zero_grad()
flush()
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding:
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
# move vae to device if we did not cache latents
if not self.is_latents_cached:
self.sd.vae.eval()
self.sd.vae.to(self.device_torch)
else:
network = BlankNetwork()
# offload it. Already cached
self.sd.vae.to('cpu')
flush()
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
# activate network if it exits
with network:
with torch.set_grad_enabled(grad_on_text_encoder):
embedding_list = []
# embed the prompts
for prompt in conditioned_prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
# you can expand these in a child class to make customization easier
def calculate_loss(
self,
noise_pred: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
loss_target = self.train_config.loss_target
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
)
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype)
prior_mask_multiplier = None
target_mask_multiplier = None
if self.sd.prediction_type == 'v_prediction':
if self.train_config.match_noise_norm:
# match the norm of the noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.inverted_mask_prior:
# we need to make the noise prediction be a masked blending of noise and prior_pred
prior_mask_multiplier = 1.0 - mask_multiplier
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0
target = noise
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
# set masked multiplier to 1.0 so we dont double apply it
# mask_multiplier = 1.0
elif prior_pred is not None:
# matching adapter prediction
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
pred = noise_pred
ignore_snr = False
if loss_target == 'source' or loss_target == 'unaugmented':
# ignore_snr = True
if batch.sigmas is None:
raise ValueError("Batch sigmas is None. This should not happen")
# src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
weighing = batch.sigmas ** -2.0
if loss_target == 'source':
# denoise the latent and compare to the latent in the batch
target = batch.latents
elif loss_target == 'unaugmented':
# we have to encode images into latents for now
# we also denoise as the unaugmented tensor is not a noisy diffirental
with torch.no_grad():
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
target = unaugmented_latents.detach()
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = target # we are computing loss against denoise latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# mse loss without reduction
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
if self.train_config.inverted_mask_prior:
# to a loss to unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
pred.float(),
reduction="none"
)
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
return loss
# back propagate loss to free ram
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def get_guided_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
# Perform targeted guidance (working title)
conditional_noisy_latents = noisy_latents # target images
dtype = get_torch_dtype(self.train_config.dtype)
if batch.unconditional_latents is not None:
# unconditional latents are the "neutral" images. Add noise here identical to
# the noise added to the conditional latents, at the same timesteps
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
batch.unconditional_latents, noise, timesteps
)
# calculate the differential between our conditional (target image) and out unconditional (neutral image)
target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
target_differential_noise = target_differential_noise.detach()
# add the target differential to the target latents as if it were noise with the scheduler, scaled to
# the current timestep. Scaling the noise here is important as it scales our guidance to the current
# timestep. This is the key to making the guidance work.
guidance_latents = self.sd.noise_scheduler.add_noise(
conditional_noisy_latents,
target_differential_noise,
timesteps
)
# Disable the LoRA network so we can predict parent network knowledge without it
self.network.is_active = False
self.sd.unet.eval()
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
# This acts as our control to preserve the unaltered parts of the image.
baseline_prediction = self.sd.predict_noise(
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
# turn the LoRA network back on.
self.sd.unet.train()
self.network.is_active = True
self.network.multiplier = network_weight_list
# do our prediction with LoRA active on the scaled guidance latents
prediction = self.sd.predict_noise(
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# remove the baseline prediction from our prediction to get the differential between the two
# all that should be left is the differential between the conditional and unconditional images
pred_differential_noise = prediction - baseline_prediction
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
# not the timestep scaled noise that was added. This is the diffusion training process.
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
# differential between the conditional and unconditional images (target)
loss = torch.nn.functional.mse_loss(
pred_differential_noise.float(),
target_differential_noise.float(),
reduction="none"
)
loss = loss.mean([1, 2, 3])
loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
flush()
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
# detach it so parent class can run backward on no grads without throwing error
loss = loss.detach()
loss.requires_grad_(True)
return loss
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# dont use network on this
# self.network.multiplier = 0.0
was_network_active = self.network.is_active
self.network.is_active = False
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
# self.network.multiplier = network_weight_list
self.network.is_active = was_network_active
return prior_pred
def before_unet_predict(self):
pass
def after_unet_predict(self):
pass
def end_of_training_loop(self):
pass
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
adapter_images = None
sigmas = None
if has_adapter_img and (self.adapter or self.assistant_adapter):
with self.timer('get_adapter_images'):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
# match in channels
if self.assistant_adapter is not None:
in_channels = self.assistant_adapter.config.in_channels
if adapter_images.shape[1] != in_channels:
# we need to match the channels
adapter_images = adapter_images[:, :in_channels, :, :]
else:
raise NotImplementedError("Adapter images now must be loaded with dataloader")
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling no supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
def get_adapter_multiplier():
if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant.
return 1.0
elif match_adapter_assist:
# training a texture. We want it high
adapter_strength_min = 0.9
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
adapter_strength_min = 0.4
adapter_strength_max = 0.7
# adapter_strength_min = 0.9
# adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return adapter_conditioning_scale
# flush()
with self.timer('grad_setup'):
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding:
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
# set the weights
network.multiplier = network_weight_list
self.optimizer.zero_grad(set_to_none=True)
# activate network if it exits
prompts_1 = conditioned_prompts
prompts_2 = None
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
prompts_1 = batch.get_caption_short_list()
prompts_2 = conditioned_prompts
# make the batch splits
if self.train_config.single_item_batching:
if self.model_config.refiner_name_or_path is not None:
raise ValueError("Single item batching is not supported when training the refiner")
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
noise_list = torch.chunk(noise, batch_size, dim=0)
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
if imgs is not None:
imgs_list = torch.chunk(imgs, batch_size, dim=0)
else:
imgs_list = [None for _ in range(batch_size)]
if adapter_images is not None:
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
else:
adapter_images_list = [None for _ in range(batch_size)]
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
if prompts_2 is None:
prompt_2_list = [None for _ in range(batch_size)]
else:
prompt_2_list = [[prompt] for prompt in prompts_2]
else:
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditioned_prompts_list = [prompts_1]
imgs_list = [imgs]
adapter_images_list = [adapter_images]
mask_multiplier_list = [mask_multiplier]
if prompts_2 is None:
prompt_2_list = [None]
else:
prompt_2_list = [prompts_2]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
noisy_latents_list,
noise_list,
timesteps_list,
conditioned_prompts_list,
imgs_list,
adapter_images_list,
mask_multiplier_list,
prompt_2_list
):
if self.train_config.negative_prompt is not None:
# add negative prompt
conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
range(len(conditioned_prompts))]
if prompt_2 is not None:
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if has_adapter_img and (
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the giadance later
if batch.unconditional_latents is not None:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
)
else:
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
with self.timer('backward'):
# todo we have multiplier seperated. works for now as res are not in same batch, but need to change
loss = loss * loss_multiplier.mean()
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
loss.backward()
# flush()
if not self.is_grad_accumulation_step:
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else:
# gradient accumulation. Just a place for breakpoint
pass
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
index_no_updates[
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
with torch.no_grad():
self.sd.text_encoder.get_input_embeddings().weight[
index_no_updates
] = self.orig_embeds_params[index_no_updates]
with self.timer('restore_embeddings'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
)
self.end_of_training_loop()
return loss_dict

View File

@@ -19,7 +19,7 @@ config:
max_step_saves_to_keep: 5 # only affects step counts
datasets:
- folder_path: "/path/to/dataset"
caption_type: "txt"
caption_ext: "txt"
default_caption: "[trigger]"
buckets: true
resolution: 512

View File

@@ -3,6 +3,6 @@ from collections import OrderedDict
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.4"
v["version"] = "0.1.0"
software_meta = v

View File

@@ -2,6 +2,8 @@ import copy
import json
from collections import OrderedDict
from toolkit.timer import Timer
class BaseProcess(object):
@@ -18,6 +20,9 @@ class BaseProcess(object):
self.raw_process_config = config
self.name = self.get_conf('name', self.job.name)
self.meta = copy.deepcopy(self.job.meta)
self.timer: Timer = Timer(f'{self.name} Timer')
self.performance_log_every = self.get_conf('performance_log_every', 0)
print(json.dumps(self.config, indent=4))
def get_conf(self, key, default=None, required=False, as_type=None):

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,10 @@
import random
from datetime import datetime
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Union
import torch
import yaml
from jobs.process.BaseProcess import BaseProcess
@@ -28,9 +30,18 @@ class BaseTrainProcess(BaseProcess):
self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
self.progress_bar: 'tqdm' = None
self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
# if training seed is set, use it
if self.training_seed is not None:
torch.manual_seed(self.training_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(self.training_seed)
random.seed(self.training_seed)
self.progress_bar = None
self.writer = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
self.training_folder = self.get_conf('training_folder',
self.job.training_folder if hasattr(self.job, 'training_folder') else None)
self.save_root = os.path.join(self.training_folder, self.name)
self.step = 0
self.first_step = 0
@@ -62,8 +73,7 @@ class BaseTrainProcess(BaseProcess):
self.writer = SummaryWriter(summary_dir)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.save_root, exist_ok=True)
save_dif = os.path.join(self.save_root, f'process_config_{timestamp}.yaml')
save_dif = os.path.join(self.save_root, f'config.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_process_config, f)
yaml.dump(self.job.raw_config, f)

View File

@@ -98,7 +98,7 @@ class GenerateProcess(BaseProcess):
add_prompt_file=self.generate_config.prompt_file
))
# generate images
self.sd.generate_images(prompt_image_configs)
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
print("Done generating images")
# cleanup

View File

@@ -3,10 +3,12 @@ import glob
import os
import time
from collections import OrderedDict
from typing import List, Optional
from PIL import Image
from PIL.ImageOps import exif_transpose
# from basicsr.archs.rrdbnet_arch import RRDBNet
from toolkit.basic import flush
from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
@@ -67,9 +69,10 @@ class TrainESRGANProcess(BaseTrainProcess):
self.augmentations = self.get_conf('augmentations', {})
self.torch_dtype = get_torch_dtype(self.dtype)
if self.torch_dtype == torch.bfloat16:
self.esrgan_dtype = torch.float16
self.esrgan_dtype = torch.float32
else:
self.esrgan_dtype = torch.float32
self.vgg_19 = None
self.style_weight_scalers = []
self.content_weight_scalers = []
@@ -232,6 +235,7 @@ class TrainESRGANProcess(BaseTrainProcess):
pattern_size=self.zoom,
dtype=self.torch_dtype
).to(self.device, dtype=self.torch_dtype)
self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
loss = torch.mean(self._pattern_loss(pred, target))
return loss
@@ -269,13 +273,63 @@ class TrainESRGANProcess(BaseTrainProcess):
if self.use_critic:
self.critic.save(step)
def sample(self, step=None):
def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
os.makedirs(sample_folder, exist_ok=True)
batch_sample_folder = os.path.join(self.save_root, 'samples_batch')
batch_targets = None
batch_inputs = None
if batch is not None and not os.path.exists(batch_sample_folder):
os.makedirs(batch_sample_folder, exist_ok=True)
self.model.eval()
def process_and_save(img, target_img, save_path):
img = img.to(self.device, dtype=self.esrgan_dtype)
output = self.model(img)
# output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1)
img = img.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8))
img = Image.fromarray((img * 255).astype(np.uint8))
if isinstance(target_img, torch.Tensor):
# convert to pil
target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
target_img = Image.fromarray((target_img * 255).astype(np.uint8))
# upscale to size * self.upscale_sample while maintaining pixels
output = output.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
img = img.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size
# stack input image and decoded image
target_image = target_img.resize((width, height))
output = output.resize((width, height))
img = img.resize((width, height))
output_img = Image.new('RGB', (width * 3, height))
output_img.paste(img, (0, 0))
output_img.paste(output, (width, 0))
output_img.paste(target_image, (width * 2, 0))
output_img.save(save_path)
with torch.no_grad():
for i, img_url in enumerate(self.sample_sources):
img = exif_transpose(Image.open(img_url))
@@ -295,30 +349,6 @@ class TrainESRGANProcess(BaseTrainProcess):
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
img = img
output = self.model(img)
# output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8))
# upscale to size * self.upscale_sample while maintaining pixels
output = output.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size
# stack input image and decoded image
target_image = target_image.resize((width, height))
output = output.resize((width, height))
output_img = Image.new('RGB', (width * 2, height))
output_img.paste(target_image, (0, 0))
output_img.paste(output, (width, 0))
step_num = ''
if step is not None:
@@ -327,8 +357,24 @@ class TrainESRGANProcess(BaseTrainProcess):
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename))
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(img, target_image, os.path.join(sample_folder, filename))
if batch is not None:
batch_targets = batch[0].detach()
batch_inputs = batch[1].detach()
batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
for i in range(len(batch_inputs)):
if step is not None:
# zero-pad 9 digits
step_num = f"_{str(step).zfill(9)}"
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
self.model.train()
@@ -376,13 +422,14 @@ class TrainESRGANProcess(BaseTrainProcess):
def run(self):
super().run()
self.load_datasets()
steps_per_step = (self.critic.num_critic_per_gen + 1)
max_step_epochs = self.max_steps // len(self.data_loader)
max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
num_epochs = self.epochs
if num_epochs is None or num_epochs > max_step_epochs:
num_epochs = max_step_epochs
max_epoch_steps = len(self.data_loader) * num_epochs
max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
num_steps = self.max_steps
if num_steps is None or num_steps > max_epoch_steps:
num_steps = max_epoch_steps
@@ -445,35 +492,60 @@ class TrainESRGANProcess(BaseTrainProcess):
print("Generating baseline samples")
self.sample(step=0)
# range start at self.epoch_num go to self.epochs
critic_losses = []
for epoch in range(self.epoch_num, self.epochs, 1):
if self.step_num >= self.max_steps:
break
flush()
for targets, inputs in self.data_loader:
if self.step_num >= self.max_steps:
break
with torch.no_grad():
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
is_critic_only_step = False
if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
is_critic_only_step = True
pred = self.model(inputs)
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
optimizer.zero_grad()
# dont do grads here for critic step
do_grad = not is_critic_only_step
with torch.set_grad_enabled(do_grad):
pred = self.model(inputs)
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, targets], dim=0)
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
stacked = stacked.clamp(0, 1)
self.vgg_19(stacked)
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
if torch.isnan(pred).any():
raise ValueError('pred has nan values')
if torch.isnan(targets).any():
raise ValueError('targets has nan values')
if self.use_critic:
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, targets], dim=0)
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
stacked = stacked.clamp(0, 1)
self.vgg_19(stacked)
# make sure we dont have nans
if torch.isnan(self.vgg19_pool_4.tensor).any():
raise ValueError('vgg19_pool_4 has nan values')
if is_critic_only_step:
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
critic_losses.append(critic_d_loss)
# don't do generator step
continue
else:
critic_d_loss = 0.0
# doing a regular step
if len(critic_losses) == 0:
critic_d_loss = 0
else:
critic_d_loss = sum(critic_losses) / len(critic_losses)
style_loss = self.get_style_loss() * self.style_weight
content_loss = self.get_content_loss() * self.content_weight
mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
@@ -483,10 +555,13 @@ class TrainESRGANProcess(BaseTrainProcess):
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
# make sure non nan
if torch.isnan(loss):
raise ValueError('loss is nan')
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
@@ -549,7 +624,7 @@ class TrainESRGANProcess(BaseTrainProcess):
if self.sample_every and self.step_num % self.sample_every == 0:
# print above the progress bar
self.print(f"Sampling at step {self.step_num}")
self.sample(self.step_num)
self.sample(self.step_num, batch=[targets, inputs])
if self.save_every and self.step_num % self.save_every == 0:
# print above the progress bar

View File

@@ -1,76 +0,0 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
def flush():
torch.cuda.empty_cache()
gc.collect()
class LoRAHack:
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'suppression')
class TrainLoRAHack(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.hack_config = LoRAHack(**self.get_conf('hack', {}))
def hook_before_train_loop(self):
# we don't need text encoder so move it to cpu
self.sd.text_encoder.to("cpu")
flush()
# end hook_before_train_loop
if self.hack_config.type == 'suppression':
# set all params to self.current_suppression
params = self.network.parameters()
for param in params:
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = noise * 0.001
def supress_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
loss_dict = OrderedDict(
{'sup': 0.0}
)
# increase noise
for param in self.network.parameters():
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = param.data + noise * 0.001
return loss_dict
def hook_train_loop(self, batch):
if self.hack_config.type == 'suppression':
return self.supress_loop()
else:
raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
# end hook_train_loop

View File

@@ -1,9 +1,19 @@
import copy
import os
import random
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torchvision.transforms import transforms
from tqdm import tqdm
from toolkit.basic import value_map
from toolkit.config_modules import SliderConfig
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
import gc
from toolkit import train_tools
from toolkit.prompt_utils import \
@@ -21,6 +31,11 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -42,6 +57,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim targets
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
# get presets
self.eval_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=False,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=False,
train_adapter=False,
train_embedding=False,
)
self.train_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=True,
train_adapter=False,
train_embedding=False,
)
def before_model_load(self):
pass
@@ -66,6 +102,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim list to our max steps
cache = PromptEmbedsCache()
print(f"Building prompt cache")
# get encoded latents for our prompts
with torch.no_grad():
@@ -175,31 +212,107 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.sd.vae.to(self.device_torch)
# end hook_before_train_loop
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
def before_dataset_load(self):
if self.slider_config.use_adapter == 'depth':
print(f"Loading T2I Adapter for depth")
# called before LoRA network is loaded but after model is loaded
# attach the adapter here so it is there before we load the network
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
if self.model_config.is_xl:
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
print(f"Loading T2I Adapter from {adapter_path}")
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
]
# move to device and dtype
prompt_pair.to(self.device_torch, dtype=dtype)
# dont name this adapter since we are not training it
self.t2i_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
).to(self.device_torch)
self.t2i_adapter.eval()
self.t2i_adapter.requires_grad_(False)
flush()
# get a random resolution
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
@torch.no_grad()
def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.slider_config.adapter_img_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.basename(img_path).split('.')[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
# to the smallest dimension
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
# set to eval mode
self.sd.set_device_state(self.eval_slider_device_state)
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
]
# move to device and dtype
prompt_pair.to(self.device_torch, dtype=dtype)
# get a random resolution
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
down_kwargs = copy.deepcopy(pred_kwargs)
if 'down_block_additional_residuals' in down_kwargs:
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
if dbr_batch_size != dn.shape[0]:
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
down_kwargs['down_block_additional_residuals'] = [
torch.cat([sample.clone()] * amount_to_add) for sample in
down_kwargs['down_block_additional_residuals']
]
return self.sd.predict_noise(
latents=dn,
text_embeddings=train_tools.concat_prompt_embeddings(
@@ -209,9 +322,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**down_kwargs
)
with torch.no_grad():
adapter_images = None
self.sd.unet.eval()
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
from_batch = False
@@ -219,9 +336,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
# traing from a batch of images, not generating ourselves
from_batch = True
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.slider_config.adapter_img_dir is not None:
adapter_images = self.get_adapter_images(batch)
adapter_strength_min = 0.9
adapter_strength_max = 1.0
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
def rand_strength(sample):
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
down_block_additional_residuals = self.t2i_adapter(adapter_images)
down_block_additional_residuals = [
rand_strength(sample) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
current_timestep = timesteps
else:
@@ -229,14 +369,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.train_config.max_denoising_steps, device=self.device_torch
)
self.optimizer.zero_grad()
# ger a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=height,
@@ -249,32 +386,60 @@ class TrainSliderProcess(BaseSDTrainProcess):
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
with self.network:
assert self.network.is_active
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
assert not self.network.is_active
self.sd.unet.eval()
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
prompt_pair.positive_target, # unconditional
prompt_pair.target_class, # target
self.train_config.batch_size,
),
start_timesteps=0,
total_timesteps=timesteps_to,
guidance_scale=3,
)
noise_scheduler.set_timesteps(1000)
# split the latents into out prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# split the latents into out prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
# flush() # 4.2GB to 3GB on 512x512
mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
has_mask = False
if batch and batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling no supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
has_mask = True
if has_mask:
unmasked_target = get_noise_pred(
prompt_pair.positive_target, # negative prompt
prompt_pair.target_class, # positive prompt
1,
current_timestep,
denoised_latents
)
unmasked_target = unmasked_target.detach()
unmasked_target.requires_grad = False
else:
unmasked_target = None
# 4.20 GB RAM for 512x512
positive_latents = get_noise_pred(
@@ -286,7 +451,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
positive_latents = positive_latents.detach()
positive_latents.requires_grad = False
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
neutral_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -297,7 +461,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
neutral_latents = neutral_latents.detach()
neutral_latents.requires_grad = False
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
unconditional_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -308,13 +471,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
unconditional_latents = unconditional_latents.detach()
unconditional_latents.requires_grad = False
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
denoised_latents = denoised_latents.detach()
flush() # 4.2GB to 3GB on 512x512
self.sd.set_device_state(self.train_slider_device_state)
self.sd.unet.train()
# start accumulating gradients
self.optimizer.zero_grad(set_to_none=True)
# 4.20 GB RAM for 512x512
anchor_loss_float = None
if len(self.anchor_pairs) > 0:
with torch.no_grad():
@@ -369,9 +533,34 @@ class TrainSliderProcess(BaseSDTrainProcess):
del anchor_target_noise
# move anchor back to cpu
anchor.to("cpu")
flush()
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
with torch.no_grad():
if self.slider_config.low_ram:
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
unconditional_latents_chunks = torch.chunk(
unconditional_latents.detach(),
self.prompt_chunk_size,
dim=0
)
mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
if unmasked_target is not None:
unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
else:
unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
else:
# run through in one instance
prompt_pair_chunks = [prompt_pair.detach()]
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
positive_latents_chunks = [positive_latents.detach()]
neutral_latents_chunks = [neutral_latents.detach()]
unconditional_latents_chunks = [unconditional_latents.detach()]
mask_multiplier_chunks = [mask_multiplier]
unmasked_target_chunks = [unmasked_target]
# flush()
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
# 3.28 GB RAM for 512x512
with self.network:
@@ -381,13 +570,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latent_chunk, \
positive_latents_chunk, \
neutral_latents_chunk, \
unconditional_latents_chunk \
unconditional_latents_chunk, \
mask_multiplier_chunk, \
unmasked_target_chunk \
in zip(
prompt_pair_chunks,
denoised_latent_chunks,
positive_latents_chunks,
neutral_latents_chunks,
unconditional_latents_chunks,
mask_multiplier_chunks,
unmasked_target_chunks
):
self.network.multiplier = prompt_pair_chunk.multiplier_list
target_latents = get_noise_pred(
@@ -421,17 +614,43 @@ class TrainSliderProcess(BaseSDTrainProcess):
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
# do inverted mask to preserve non masked
if has_mask and unmasked_target_chunk is not None:
loss = loss * mask_multiplier_chunk
# match the mask unmasked_target_chunk
mask_target_loss = torch.nn.functional.mse_loss(
target_latents.float(),
unmasked_target_chunk.float(),
reduction="none"
)
mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
loss += mask_target_loss
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
if from_batch:
# match batch size
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
self.train_config.min_snr_gamma)
else:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
if from_batch:
# match batch size
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
self.train_config.min_snr_gamma)
else:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
self.train_config.min_snr_gamma)
loss = loss.mean() * prompt_pair_chunk.weight
@@ -440,7 +659,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
del target_latents
del offset_neutral
del loss
flush()
# flush()
optimizer.step()
lr_scheduler.step()
@@ -457,7 +676,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
# move back to cpu
prompt_pair.to("cpu")
flush()
# flush()
# reset network
self.network.multiplier = 1.0

View File

@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess

View File

@@ -154,28 +154,28 @@ class Critic:
# train critic here
self.model.train()
self.model.requires_grad_(True)
self.optimizer.zero_grad()
critic_losses = []
for i in range(self.num_critic_per_gen):
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
stacked_output = self.model(inputs)
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
stacked_output = self.model(inputs).float()
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
self.optimizer.zero_grad()
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# avg loss
loss = np.mean(critic_losses)

View File

@@ -1,9 +1,9 @@
torch
torchvision
safetensors
diffusers
transformers
lycoris_lora
diffusers==0.21.3
git+https://github.com/huggingface/transformers.git
lycoris-lora==1.8.3
flatten_json
pyyaml
oyaml
@@ -15,4 +15,10 @@ accelerate
toml
albumentations
pydantic
omegaconf
omegaconf
k-diffusion
open_clip_torch
timm
prodigyopt
controlnet_aux==0.0.7
python-dotenv

14
run.py
View File

@@ -1,8 +1,22 @@
import os
import sys
from typing import Union, OrderedDict
from dotenv import load_dotenv
# Load the .env file if it exists
load_dotenv()
sys.path.insert(0, os.getcwd())
# must come before ANY torch or fastai imports
# import toolkit.cuda_malloc
# turn off diffusers telemetry until I can figure out how to make it opt-in
os.environ['DISABLE_TELEMETRY'] = 'YES'
# check if we have DEBUG_TOOLKIT in env
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
# set torch to trace mode
import torch
torch.autograd.set_detect_anomaly(True)
import argparse
from toolkit.job import get_job

128
scripts/convert_cog.py Normal file
View File

@@ -0,0 +1,128 @@
import json
from collections import OrderedDict
import os
import torch
from safetensors import safe_open
from safetensors.torch import save_file
device = torch.device('cpu')
# [diffusers] -> kohya
embedding_mapping = {
'text_encoders_0': 'clip_l',
'text_encoders_1': 'clip_g'
}
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps')
sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json')
# load keymap
with open(sdxl_keymap_path, 'r') as f:
ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']
# invert the item / key pairs
diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()}
def get_ldm_key(diffuser_key):
diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}"
diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight')
diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight')
diffuser_key = diffuser_key.replace('_alpha', '.alpha')
diffuser_key = diffuser_key.replace('_processor_to_', '_to_')
diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.')
if diffuser_key in diffusers_ldm_keymap:
return diffusers_ldm_keymap[diffuser_key]
else:
raise KeyError(f"Key {diffuser_key} not found in keymap")
def convert_cog(lora_path, embedding_path):
embedding_state_dict = OrderedDict()
lora_state_dict = OrderedDict()
# # normal dict
# normal_dict = OrderedDict()
# example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors"
# with safe_open(example_path, framework="pt", device='cpu') as f:
# keys = list(f.keys())
# for key in keys:
# normal_dict[key] = f.get_tensor(key)
with safe_open(embedding_path, framework="pt", device='cpu') as f:
keys = list(f.keys())
for key in keys:
new_key = embedding_mapping[key]
embedding_state_dict[new_key] = f.get_tensor(key)
with safe_open(lora_path, framework="pt", device='cpu') as f:
keys = list(f.keys())
lora_rank = None
# get the lora dim first. Check first 3 linear layers just to be safe
for key in keys:
new_key = get_ldm_key(key)
tensor = f.get_tensor(key)
num_checked = 0
if len(tensor.shape) == 2:
this_dim = min(tensor.shape)
if lora_rank is None:
lora_rank = this_dim
elif lora_rank != this_dim:
raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
else:
num_checked += 1
if num_checked >= 3:
break
for key in keys:
new_key = get_ldm_key(key)
tensor = f.get_tensor(key)
if new_key.endswith('.lora_down.weight'):
alpha_key = new_key.replace('.lora_down.weight', '.alpha')
# diffusers does not have alpha, they usa an alpha multiplier of 1 which is a tensor weight of the dims
# assume first smallest dim is the lora rank if shape is 2
lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank
lora_state_dict[new_key] = tensor
return lora_state_dict, embedding_state_dict
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'lora_path',
type=str,
help='Path to lora file'
)
parser.add_argument(
'embedding_path',
type=str,
help='Path to embedding file'
)
parser.add_argument(
'--lora_output',
type=str,
default="lora_output",
)
parser.add_argument(
'--embedding_output',
type=str,
default="embedding_output",
)
args = parser.parse_args()
lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path)
# save them
save_file(lora_state_dict, args.lora_output)
save_file(embedding_state_dict, args.embedding_output)
print(f"Saved lora to {args.lora_output}")
print(f"Saved embedding to {args.embedding_output}")

View File

@@ -0,0 +1,57 @@
import argparse
from collections import OrderedDict
import torch
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
parser = argparse.ArgumentParser()
parser.add_argument(
'input_path',
type=str,
help='Path to original sdxl model'
)
parser.add_argument(
'output_path',
type=str,
help='output path'
)
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
device = torch.device('cpu')
dtype = torch.float32
print(f"Loading model from {args.input_path}")
diffusers_model_config = ModelConfig(
name_or_path=args.input_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
print(f"Loaded model from {args.input_path}")
diffusers_sd.pipeline.fuse_lora()
meta = OrderedDict()
diffusers_sd.save(args.output_path, meta=meta)
print(f"Saved to {args.output_path}")

View File

@@ -0,0 +1,67 @@
import argparse
from collections import OrderedDict
import torch
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
parser = argparse.ArgumentParser()
parser.add_argument(
'input_path',
type=str,
help='Path to original sdxl model'
)
parser.add_argument(
'output_path',
type=str,
help='output path'
)
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
device = torch.device('cpu')
dtype = torch.float32
print(f"Loading model from {args.input_path}")
if args.sdxl:
adapter_id = "latent-consistency/lcm-lora-sdxl"
if args.refiner:
adapter_id = "latent-consistency/lcm-lora-sdxl"
elif args.ssd:
adapter_id = "latent-consistency/lcm-lora-ssd-1b"
else:
adapter_id = "latent-consistency/lcm-lora-sdv1-5"
diffusers_model_config = ModelConfig(
name_or_path=args.input_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
print(f"Loaded model from {args.input_path}")
diffusers_sd.pipeline.load_lora_weights(adapter_id)
diffusers_sd.pipeline.fuse_lora()
meta = OrderedDict()
diffusers_sd.save(args.output_path, meta=meta)
print(f"Saved to {args.output_path}")

View File

@@ -1,547 +0,0 @@
import gc
import time
import argparse
import itertools
import math
import os
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import custom_tools.train_tools as train_tools
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
# perlin_noise,
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
SD_SCRIPTS_ROOT = os.path.join(PROJECT_ROOT, "repositories", "sd-scripts")
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# replace captions with names
if args.name_replace is not None:
print(f"Replacing captions [name] with '{args.name_replace}'")
train_dataset_group = train_tools.replace_filewords_in_dataset_group(
train_dataset_group, args
)
# acceleratorを準備する
print("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
# verify load/save model formats
if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None
else:
src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors
else:
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
if args.sample_first or args.sample_only:
# Do initial sample before starting training
train_tools.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer,
text_encoder, unet, force_sample=True)
if args.sample_only:
return
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
if args.train_noise_seed is not None:
torch.manual_seed(args.train_noise_seed)
torch.cuda.manual_seed(args.train_noise_seed)
# make same seed for each item in the batch by stacking them
single_noise = torch.randn_like(latents[0])
noise = torch.stack([single_noise for _ in range(b_size)])
noise = noise.to(latents.device)
elif args.seed_lock:
noise = train_tools.get_noise_from_latents(latents)
else:
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# elif args.perlin_noise:
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
# checking for saving is in util
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--no_token_padding",
action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
)
parser.add_argument(
"--stop_text_encoder_training",
type=int,
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--sample_first",
action="store_true",
help="Sample first interval before training",
default=False
)
parser.add_argument(
"--name_replace",
type=str,
help="Replaces [name] in prompts. Used is sampling, training, and regs",
default=None
)
parser.add_argument(
"--train_noise_seed",
type=int,
help="Use custom seed for training noise",
default=None
)
parser.add_argument(
"--sample_only",
action="store_true",
help="Only generate samples. Used for generating training data with specific seeds to alter during training",
default=False
)
parser.add_argument(
"--seed_lock",
action="store_true",
help="Locks the seed to the latent images so the same latent will always have the same noise",
default=False
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -0,0 +1,130 @@
from collections import OrderedDict
import torch
from safetensors.torch import load_file
import argparse
import os
import json
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')
# load keymap
with open(keymap_path, 'r') as f:
keymap = json.load(f)
lora_keymap = OrderedDict()
# convert keymap to lora key naming
for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
# skip it
continue
# sdxl has same te for locon with kohya and ours
if ldm_key.startswith('conditioner'):
#skip it
continue
# ignore vae
if ldm_key.startswith('first_stage_model'):
continue
ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
ldm_key = ldm_key.replace('.weight', '')
ldm_key = ldm_key.replace('.', '_')
diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
diffusers_key = diffusers_key.replace('.weight', '')
diffusers_key = diffusers_key.replace('.', '_')
lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
parser = argparse.ArgumentParser()
parser.add_argument("input", help="input file")
parser.add_argument("input2", help="input2 file")
args = parser.parse_args()
# name = args.name
# if args.sdxl:
# name += '_sdxl'
# elif args.sd2:
# name += '_sd2'
# else:
# name += '_sd1'
name = 'stable_diffusion_locon_sdxl'
locon_save = load_file(args.input)
our_save = load_file(args.input2)
our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))
print(f"we have {len(our_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")
save_dtype = torch.float16
print(f"our extra keys: {our_extra_keys}")
print(f"locon extra keys: {locon_extra_keys}")
def export_state_dict(our_save):
converted_state_dict = OrderedDict()
for key, value in our_save.items():
# test encoders share keys for some reason
if key.startswith('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
else:
converted_key = key
for ldm_key, diffusers_key in lora_keymap.items():
if converted_key == diffusers_key:
converted_key = ldm_key
converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
return converted_state_dict
def import_state_dict(loaded_state_dict):
converted_state_dict = OrderedDict()
for key, value in loaded_state_dict.items():
if key.startswith('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
else:
converted_key = key
for ldm_key, diffusers_key in lora_keymap.items():
if converted_key == ldm_key:
converted_key = diffusers_key
converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
return converted_state_dict
# check it again
converted_state_dict = export_state_dict(our_save)
converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))
print(f"we have {len(converted_extra_keys)} extra keys")
print(f"locon has {len(locon_extra_keys)} extra keys")
print(f"our extra keys: {converted_extra_keys}")
# convert back
cycle_state_dict = import_state_dict(converted_state_dict)
cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))
print(f"we have {len(our_extra_keys)} extra keys")
print(f"cycle has {len(cycle_extra_keys)} extra keys")
# save keymap
to_save = OrderedDict()
to_save['ldm_diffusers_keymap'] = lora_keymap
with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
json.dump(to_save, f, indent=4)

View File

@@ -6,6 +6,8 @@ import os
# add project root to sys path
import sys
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
@@ -50,6 +52,8 @@ parser.add_argument(
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
@@ -60,24 +64,76 @@ find_matches = False
print(f'Loading diffusers model')
diffusers_model_config = ModelConfig(
name_or_path=file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
# delete things we dont need
del diffusers_sd.tokenizer
flush()
ignore_ldm_begins_with = []
diffusers_file_path = file_path
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
# if args.refiner:
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
if not args.refiner:
diffusers_model_config = ModelConfig(
name_or_path=diffusers_file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
# delete things we dont need
del diffusers_sd.tokenizer
flush()
print(f'Loading ldm model')
diffusers_state_dict = diffusers_sd.state_dict()
else:
# refiner wont work directly with stable diffusion
# so we need to load the model and then load the state dict
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
diffusers_file_path,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
).to(device)
# diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
# file_path,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# ).to(device)
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te0"
SD_PREFIX_TEXT_ENCODER2 = "te1"
diffusers_state_dict = OrderedDict()
for k, v in diffusers_pipeline.vae.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
diffusers_state_dict[new_key] = v
for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
diffusers_state_dict[new_key] = v
for k, v in diffusers_pipeline.unet.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
diffusers_state_dict[new_key] = v
# add ignore ones as we are only going to focus on unet and copy the rest
# ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
print(f'Loading ldm model')
diffusers_state_dict = diffusers_sd.state_dict()
diffusers_dict_keys = list(diffusers_state_dict.keys())
ldm_state_dict = load_file(file_path)
@@ -93,18 +149,26 @@ total_keys = len(ldm_dict_keys)
matched_ldm_keys = []
matched_diffusers_keys = []
error_margin = 1e-4
error_margin = 1e-8
tmp_merge_key = "TMP___MERGE"
te_suffix = ''
proj_pattern_weight = None
proj_pattern_bias = None
text_proj_layer = None
if args.sdxl:
if args.sdxl or args.ssd:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "conditioner.embedders.1.model.text_projection"
if args.refiner:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "conditioner.embedders.0.model.text_projection"
if args.sd2:
te_suffix = ''
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
@@ -112,10 +176,13 @@ if args.sd2:
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "cond_stage_model.model.text_projection"
if args.sdxl or args.sd2:
if args.sdxl or args.sd2 or args.ssd or args.refiner:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
else:
d_model = 1024
@@ -139,7 +206,7 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
# make diffusers convertable_dict
diffusers_state_dict[
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val
# add operator
ldm_operator_map[ldm_key] = {
@@ -148,7 +215,6 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
],
"target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
}
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
@@ -189,7 +255,7 @@ if args.sdxl or args.sd2:
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
# make diffusers convertable_dict
diffusers_state_dict[
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val
# add operator
ldm_operator_map[ldm_key] = {
@@ -198,7 +264,6 @@ if args.sdxl or args.sd2:
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
],
# "target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
}
# add diffusers operators
@@ -237,11 +302,11 @@ for ldm_key in ldm_dict_keys:
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
# That was easy. Same key
if ldm_key == diffusers_key:
ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key)
matched_diffusers_keys.append(diffusers_key)
break
# if ldm_key == diffusers_key:
# ldm_diffusers_keymap[ldm_key] = diffusers_key
# matched_ldm_keys.append(ldm_key)
# matched_diffusers_keys.append(diffusers_key)
# break
# if we already have this key mapped, skip it
if diffusers_key in matched_diffusers_keys:
@@ -266,7 +331,7 @@ for ldm_key in ldm_dict_keys:
did_reduce_diffusers = True
# check to see if they match within a margin of error
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
if mse < error_margin:
ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key)
@@ -289,6 +354,10 @@ pbar.close()
name = args.name
if args.sdxl:
name += '_sdxl'
elif args.ssd:
name += '_ssd'
elif args.refiner:
name += '_refiner'
elif args.sd2:
name += '_sd2'
else:
@@ -359,13 +428,35 @@ for key in unmatched_ldm_keys:
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
# do cleanup of some left overs and bugs
to_remove = []
for ldm_key, diffusers_key in ldm_diffusers_keymap.items():
# get rid of tmp merge keys used to slicing
if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key:
to_remove.append(ldm_key)
for key in to_remove:
del ldm_diffusers_keymap[key]
to_remove = []
# remove identical shape mappings. Not sure why they exist but they do
for ldm_key, shape_list in ldm_diffusers_shape_map.items():
# remove identical shape mappings. Not sure why they exist but they do
# convert to json string to make it easier to compare
ldm_shape = json.dumps(shape_list[0])
diffusers_shape = json.dumps(shape_list[1])
if ldm_shape == diffusers_shape:
to_remove.append(ldm_key)
for key in to_remove:
del ldm_diffusers_shape_map[key]
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
save_obj = OrderedDict()
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map
with open(dest_path, 'w') as f:
f.write(json.dumps(save_obj, indent=4))

View File

@@ -1,37 +1,107 @@
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
# make sure we can import from the toolkit
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
import cv2
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
from toolkit.paths import SD_SCRIPTS_ROOT
from toolkit.image_utils import show_img
sys.path.append(SD_SCRIPTS_ROOT)
from library.model_util import load_vae
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
resolution = 1024
bucket_tolerance = 64
batch_size = 4
batch_size = 1
##
dataset_config = DatasetConfig(
folder_path=dataset_folder,
dataset_path=dataset_folder,
resolution=resolution,
caption_type='txt',
caption_ext='json',
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
poi='person',
augmentations=[
{
'method': 'RandomBrightnessContrast',
'brightness_limit': (-0.3, 0.3),
'contrast_limit': (-0.3, 0.3),
'brightness_by_max': False,
'p': 1.0
},
{
'method': 'HueSaturationValue',
'hue_shift_limit': (-0, 0),
'sat_shift_limit': (-40, 40),
'val_shift_limit': (-40, 40),
'p': 1.0
},
# {
# 'method': 'RGBShift',
# 'r_shift_limit': (-20, 20),
# 'g_shift_limit': (-20, 20),
# 'b_shift_limit': (-20, 20),
# 'p': 1.0
# },
]
)
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
# run through an epoch ang check sizes
for batch in dataloader:
print(list(batch[0].shape))
dataloader_iterator = iter(dataloader)
for epoch in range(args.epochs):
for batch in dataloader:
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are size by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
min_val = big_img.min()
max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1)
# convert to image
img = transforms.ToPILImage()(big_img)
show_img(img)
time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
cv2.destroyAllWindows()
print('done')

View File

@@ -3,6 +3,8 @@ import os
# add project root to sys path
import sys
from tqdm import tqdm
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
@@ -20,6 +22,8 @@ from toolkit.stable_diffusion_model import StableDiffusion
# you probably wont need this. Unless they change them.... again... again
# on second thought, you probably will
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = torch.device('cpu')
dtype = torch.float32
@@ -109,7 +113,26 @@ keys_in_both.sort()
if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0:
print("All keys match!")
exit(0)
print("Checking values...")
mismatch_keys = []
loss = torch.nn.MSELoss()
tolerance = 1e-6
for key in tqdm(keys_in_both):
if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance:
print(f"Values for key {key} don't match!")
print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}")
mismatch_keys.append(key)
if len(mismatch_keys) == 0:
print("All values match!")
else:
print("Some valued font match!")
print(mismatch_keys)
mismatched_path = os.path.join(project_root, 'config', 'mismatch.json')
with open(mismatched_path, 'w') as f:
f.write(json.dumps(mismatch_keys, indent=4))
exit(0)
else:
print("Keys don't match!, generating info...")
@@ -132,17 +155,17 @@ for key in keys_not_in_state_dict_2:
# print(json_data)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
state_dict_1_filename = os.path.basename(args.file_1[0])
state_dict_2_filename = os.path.basename(args.file_2[0])
# state_dict_2_filename = os.path.basename(args.file_2[0])
# save key names for each in own file
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
f.write(json.dumps(state_dict_1_keys, indent=4))
with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f:
f.write(json.dumps(state_dict_2_keys, indent=4))
with open(json_save_path, 'w') as f:

View File

@@ -1,4 +1,50 @@
import gc
import torch
def value_map(inputs, min_in, max_in, min_out, max_out):
return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
def flush(garbage_collect=True):
torch.cuda.empty_cache()
if garbage_collect:
gc.collect()
def get_mean_std(tensor):
if len(tensor.shape) == 3:
tensor = tensor.unsqueeze(0)
elif len(tensor.shape) != 4:
raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
mean, variance = torch.mean(
tensor, dim=[2, 3], keepdim=True
), torch.var(
tensor, dim=[2, 3],
keepdim=True
)
std = torch.sqrt(variance + 1e-5)
return mean, std
def adain(content_features, style_features):
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
# Step 1: Calculate mean and variance of content features
content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features,
dim=[2, 3],
keepdim=True)
# Step 2: Calculate mean and variance of style features
style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3],
keepdim=True)
# Step 3: Normalize content features
content_std = torch.sqrt(content_var + 1e-5)
normalized_content = (content_features - content_mean) / content_std
# Step 4: Scale and shift normalized content with style's statistics
style_std = torch.sqrt(style_var + 1e-5)
stylized_content = normalized_content * style_std + style_mean
return stylized_content

127
toolkit/buckets.py Normal file
View File

@@ -0,0 +1,127 @@
from typing import Type, List, Union, TypedDict
class BucketResolution(TypedDict):
width: int
height: int
# resolutions SDXL was trained on with a 1024x1024 base resolution
resolutions_1024: List[BucketResolution] = [
# SDXL Base resolution
{"width": 1024, "height": 1024},
# SDXL Resolutions, widescreen
{"width": 2048, "height": 512},
{"width": 1984, "height": 512},
{"width": 1920, "height": 512},
{"width": 1856, "height": 512},
{"width": 1792, "height": 576},
{"width": 1728, "height": 576},
{"width": 1664, "height": 576},
{"width": 1600, "height": 640},
{"width": 1536, "height": 640},
{"width": 1472, "height": 704},
{"width": 1408, "height": 704},
{"width": 1344, "height": 704},
{"width": 1344, "height": 768},
{"width": 1280, "height": 768},
{"width": 1216, "height": 832},
{"width": 1152, "height": 832},
{"width": 1152, "height": 896},
{"width": 1088, "height": 896},
{"width": 1088, "height": 960},
{"width": 1024, "height": 960},
# SDXL Resolutions, portrait
{"width": 960, "height": 1024},
{"width": 960, "height": 1088},
{"width": 896, "height": 1088},
{"width": 896, "height": 1152}, # 2:3
{"width": 832, "height": 1152},
{"width": 832, "height": 1216},
{"width": 768, "height": 1280},
{"width": 768, "height": 1344},
{"width": 704, "height": 1408},
{"width": 704, "height": 1472},
{"width": 640, "height": 1536},
{"width": 640, "height": 1600},
{"width": 576, "height": 1664},
{"width": 576, "height": 1728},
{"width": 576, "height": 1792},
{"width": 512, "height": 1856},
{"width": 512, "height": 1920},
{"width": 512, "height": 1984},
{"width": 512, "height": 2048},
]
def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
# determine scaler form 1024 to resolution
scaler = resolution / 1024
bucket_size_list = []
for bucket in resolutions_1024:
# must be divisible by 8
width = int(bucket["width"] * scaler)
height = int(bucket["height"] * scaler)
if width % divisibility != 0:
width = width - (width % divisibility)
if height % divisibility != 0:
height = height - (height % divisibility)
bucket_size_list.append({"width": width, "height": height})
return bucket_size_list
def get_resolution(width, height):
num_pixels = width * height
# determine same number of pixels for square image
square_resolution = int(num_pixels ** 0.5)
return square_resolution
def get_bucket_for_image_size(
width: int,
height: int,
bucket_size_list: List[BucketResolution] = None,
resolution: Union[int, None] = None,
divisibility: int = 8
) -> BucketResolution:
if bucket_size_list is None and resolution is None:
# get resolution from width and height
resolution = get_resolution(width, height)
if bucket_size_list is None:
# if real resolution is smaller, use that instead
real_resolution = get_resolution(width, height)
resolution = min(resolution, real_resolution)
bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
# Check for exact match first
for bucket in bucket_size_list:
if bucket["width"] == width and bucket["height"] == height:
return bucket
# If exact match not found, find the closest bucket
closest_bucket = None
min_removed_pixels = float("inf")
for bucket in bucket_size_list:
scale_w = bucket["width"] / width
scale_h = bucket["height"] / height
# To minimize pixels, we use the larger scale factor to minimize the amount that has to be cropped.
scale = max(scale_w, scale_h)
new_width = int(width * scale)
new_height = int(height * scale)
removed_pixels = (new_width - bucket["width"]) * new_height + (new_height - bucket["height"]) * new_width
if removed_pixels < min_removed_pixels:
min_removed_pixels = removed_pixels
closest_bucket = bucket
if closest_bucket is None:
raise ValueError("No suitable bucket found")
return closest_bucket

View File

@@ -17,6 +17,24 @@ def get_cwd_abs_path(path):
return path
def replace_env_vars_in_string(s: str) -> str:
"""
Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
If the environment variable is not set, raise an error.
"""
def replacer(match):
var_name = match.group(1)
value = os.environ.get(var_name)
if value is None:
raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.")
return value
return re.sub(r'\$\{([^}]+)\}', replacer, s)
def preprocess_config(config: OrderedDict, name: str = None):
if "job" not in config:
raise ValueError("config file must have a job key")
@@ -81,13 +99,14 @@ def get_config(
raise ValueError(f"Could not find config file {config_file_path}")
# if we found it, check if it is a json or yaml file
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
with open(real_config_path, 'r', encoding='utf-8') as f:
config = json.load(f, object_pairs_hook=OrderedDict)
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
with open(real_config_path, 'r', encoding='utf-8') as f:
config = yaml.load(f, Loader=fixed_loader)
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
with open(real_config_path, 'r', encoding='utf-8') as f:
content = f.read()
content_with_env_replaced = replace_env_vars_in_string(content)
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
return preprocess_config(config, name)

View File

@@ -1,14 +1,25 @@
import os
import time
from typing import List, Optional, Literal
from typing import List, Optional, Literal, Union
import random
import torch
from toolkit.prompt_utils import PromptEmbeds
ImgExt = Literal['jpg', 'png', 'webp']
SaveFormat = Literal['safetensors', 'diffusers']
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
self.dtype: str = kwargs.get('save_dtype', 'float16')
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
if self.save_format not in ['safetensors', 'diffusers']:
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
class LogingConfig:
@@ -20,6 +31,7 @@ class LogingConfig:
class SampleConfig:
def __init__(self, **kwargs):
self.sampler: str = kwargs.get('sampler', 'ddpm')
self.sample_every: int = kwargs.get('sample_every', 100)
self.width: int = kwargs.get('width', 512)
self.height: int = kwargs.get('height', 512)
@@ -31,11 +43,59 @@ class SampleConfig:
self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
class LormModuleSettingsConfig:
def __init__(self, **kwargs):
self.contains: str = kwargs.get('contains', '4nt$3')
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
# min num parameters to attach to
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
class LoRMConfig:
def __init__(self, **kwargs):
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
self.do_conv: bool = kwargs.get('do_conv', False)
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
module_settings = kwargs.get('module_settings', [])
default_module_settings = {
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
}
module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
module_setting in module_settings]
def get_config_for_module(self, block_name):
for setting in self.module_settings:
contain_pieces = setting.contains.split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# try replacing the . with _
contain_pieces = setting.contains.replace('.', '_').split('|')
if all(contain_piece in block_name for contain_piece in contain_pieces):
return setting
# do default
return LormModuleSettingsConfig(**{
'extract_mode': self.extract_mode,
'extract_mode_param': self.extract_mode_param,
'parameter_threshold': self.parameter_threshold,
})
NetworkType = Literal['lora', 'locon', 'lorm']
class NetworkConfig:
def __init__(self, **kwargs):
self.type: str = kwargs.get('type', 'lora')
self.type: NetworkType = kwargs.get('type', 'lora')
rank = kwargs.get('rank', None)
linear = kwargs.get('linear', None)
if rank is not None:
@@ -48,7 +108,37 @@ class NetworkConfig:
self.alpha: float = kwargs.get('alpha', 1.0)
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.normalize = kwargs.get('normalize', False)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
if lorm is not None:
self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
if self.type == 'lorm':
# set linear to arbitrary values so it makes them
self.linear = 4
self.rank = 4
if self.lorm_config.do_conv:
self.conv = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
self.downscale_factor: int = kwargs.get('downscale_factor', 8)
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None)
self.train: str = kwargs.get('train', False)
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.name_or_path = kwargs.get('name_or_path', None)
class EmbeddingConfig:
@@ -59,26 +149,94 @@ class EmbeddingConfig:
self.save_format = kwargs.get('save_format', 'safetensors')
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig:
def __init__(self, **kwargs):
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
self.steps: int = kwargs.get('steps', 1000)
self.lr = kwargs.get('lr', 1e-6)
self.unet_lr = kwargs.get('unet_lr', self.lr)
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
self.refiner_lr = kwargs.get('refiner_lr', self.lr)
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
self.adapter_lr = kwargs.get('adapter_lr', self.lr)
self.optimizer = kwargs.get('optimizer', 'adamw')
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
self.sdp = kwargs.get('sdp', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.train_refiner = kwargs.get('train_refiner', True)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.snr_gamma = kwargs.get('snr_gamma', None)
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
# this should balance the learning rate across all timesteps over time
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
self.negative_prompt = kwargs.get('negative_prompt', None)
# multiplier applied to loos on regularization images
self.reg_weight = kwargs.get('reg_weight', 1.0)
# dropout that happens before encoding. It functions independently per text encoder
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
# match the norm of the noise before computing loss. This will help the model maintain its
# current understandin of the brightness of images.
self.match_noise_norm = kwargs.get('match_noise_norm', False)
# set to -1 to accumulate gradients for entire epoch
# warning, only do this with a small dataset or you will run out of memory
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
# short long captions will double your batch size. This only works when a dataset is
# prepared with a json caption file that has both short and long captions in it. It will
# Double up every image and run it through with both short and long captions. The idea
# is that the network will learn how to generate good images with both short and long captions
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
# if above is NOT true, this will make it so the long caption foes to te2 and the short caption goes to te1 for sdxl only
self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
# basically gradient accumulation but we run just 1 item through the network
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
# for training tricks that increase batch size but need a single gradient step
self.single_item_batching = kwargs.get('single_item_batching', False)
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target',
'noise') # noise, source, unaugmented, differential_noise
# When a mask is passed in a dataset, and this is true,
# we will predict noise without a the LoRa network and use the prediction as a target for
# unmasked reign. It is unmasked regularization basically
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
self.match_adapter_chance = 1.0
class ModelConfig:
@@ -86,17 +244,27 @@ class ModelConfig:
self.name_or_path: str = kwargs.get('name_or_path', None)
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_ssd: bool = kwargs.get('is_ssd', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
self.vae_path = kwargs.get('vae_path', None)
self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
self._original_refiner_name_or_path = self.refiner_name_or_path
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
# only for SDXL models for now
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
self.experimental_xl: bool = kwargs.get('experimental_xl', False)
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')
if self.is_ssd:
# sed sdxl as true since it is mostly the same architecture
self.is_xl = True
class ReferenceDatasetConfig:
def __init__(self, **kwargs):
@@ -126,6 +294,14 @@ class SliderTargetConfig:
self.shuffle: bool = kwargs.get('shuffle', False)
class GuidanceConfig:
def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', '')
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0)
self.positive_prompt: str = kwargs.get('positive_prompt', '')
self.negative_prompt: str = kwargs.get('negative_prompt', '')
class SliderConfigAnchors:
def __init__(self, **kwargs):
self.prompt = kwargs.get('prompt', '')
@@ -143,28 +319,41 @@ class SliderConfig:
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.low_ram = kwargs.get('low_ram', False)
# expand targets if shuffling
from toolkit.prompt_utils import get_slider_target_permutations
self.targets: List[SliderTargetConfig] = []
targets = [SliderTargetConfig(**target) for target in targets]
# do permutations if shuffle is true
print(f"Building slider targets")
for target in targets:
if target.shuffle:
target_permutations = get_slider_target_permutations(target)
target_permutations = get_slider_target_permutations(target, max_permutations=8)
self.targets = self.targets + target_permutations
else:
self.targets.append(target)
print(f"Built {len(self.targets)} slider targets (with permutations)")
class DatasetConfig:
caption_type: Literal["txt", "caption"] = 'txt'
"""
Dataset config for sd-datasets
"""
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image') # sd, slider, reference
# will be legacy
self.folder_path: str = kwargs.get('folder_path', None)
# can be json or folder path
self.dataset_path: str = kwargs.get('dataset_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
self.caption_type: str = kwargs.get('caption_type', None)
self.random_triggers: List[str] = kwargs.get('random_triggers', [])
self.caption_ext: str = kwargs.get('caption_ext', None)
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)
@@ -172,6 +361,66 @@ class DatasetConfig:
self.buckets: bool = kwargs.get('buckets', False)
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
self.is_reg: bool = kwargs.get('is_reg', False)
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',
None) # if one is set and in json data, will be used as auto crop scale point of interes
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False)
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
self.cache_latents = False
self.cache_latents_to_disk = False
# legacy compatability
legacy_caption_type = kwargs.get('caption_type', None)
if legacy_caption_type:
self.caption_ext = legacy_caption_type
self.caption_type = self.caption_ext
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
"""
This just splits up the datasets by resolutions so you dont have to do it manually
:param raw_config:
:return:
"""
# split up datasets by resolutions
new_config = []
for dataset in raw_config:
resolution = dataset.get('resolution', 512)
if isinstance(resolution, list):
resolution_list = resolution
else:
resolution_list = [resolution]
for res in resolution_list:
dataset_copy = dataset.copy()
dataset_copy['resolution'] = res
new_config.append(dataset_copy)
return new_config
class GenerateImageConfig:
@@ -191,9 +440,14 @@ class GenerateImageConfig:
# the tag [time] will be replaced with milliseconds since epoch
output_path: str = None, # full image path
output_folder: str = None, # folder to save image in if output_path is not specified
output_ext: str = 'png', # extension to save image as if output_path is not specified
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
adapter_image_path: str = None, # path to adapter image
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
latents: Union[torch.Tensor | None] = None, # input latent to start with,
extra_kwargs: dict = None, # extra data to save with prompt file
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
):
self.width: int = width
self.height: int = height
@@ -204,6 +458,7 @@ class GenerateImageConfig:
self.prompt_2: str = prompt_2
self.negative_prompt: str = negative_prompt
self.negative_prompt_2: str = negative_prompt_2
self.latents: Union[torch.Tensor | None] = latents
self.output_path: str = output_path
self.seed: int = seed
@@ -216,6 +471,10 @@ class GenerateImageConfig:
self.add_prompt_file: bool = add_prompt_file
self.output_tail: str = output_tail
self.gen_time: int = int(time.time() * 1000)
self.adapter_image_path: str = adapter_image_path
self.adapter_conditioning_scale: float = adapter_conditioning_scale
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
self.refiner_start_at = refiner_start_at
# prompt string will override any settings above
self._process_prompt_string()
@@ -370,3 +629,15 @@ class GenerateImageConfig:
self.network_multiplier = float(content)
elif flag == 'gr':
self.guidance_rescale = float(content)
elif flag == 'a':
self.adapter_conditioning_scale = float(content)
elif flag == 'ref':
self.refiner_start_at = float(content)
def post_process_embeddings(
self,
conditional_prompt_embeds: PromptEmbeds,
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
):
# this is called after prompt embeds are encoded. We can override them in the future here
pass

93
toolkit/cuda_malloc.py Normal file
View File

@@ -0,0 +1,93 @@
# ref comfy ui
import os
import importlib.util
# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
import ctypes
# Define necessary C structures and types
class DISPLAY_DEVICEA(ctypes.Structure):
_fields_ = [
('cb', ctypes.c_ulong),
('DeviceName', ctypes.c_char * 32),
('DeviceString', ctypes.c_char * 128),
('StateFlags', ctypes.c_ulong),
('DeviceID', ctypes.c_char * 128),
('DeviceKey', ctypes.c_char * 128)
]
# Load user32.dll
user32 = ctypes.windll.user32
# Call EnumDisplayDevicesA
def enum_display_devices():
device_info = DISPLAY_DEVICEA()
device_info.cb = ctypes.sizeof(device_info)
device_index = 0
gpu_names = set()
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
device_index += 1
gpu_names.add(device_info.DeviceString.decode('utf-8'))
return gpu_names
return enum_display_devices()
else:
return set()
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950",
"GeForce 945M",
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745",
"Quadro K620",
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000",
"Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630"
}
def cuda_malloc_supported():
try:
names = get_gpu_names()
except:
names = set()
for x in names:
if "NVIDIA" in x:
for b in blacklist:
if b in x:
return False
return True
cuda_malloc = False
if not cuda_malloc:
try:
version = ""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
cuda_malloc = cuda_malloc_supported()
except:
pass
if cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
print("CUDA Malloc Async Enabled")

View File

@@ -1,10 +1,13 @@
import copy
import json
import os
import random
from typing import List
import traceback
from functools import lru_cache
from typing import List, TYPE_CHECKING
import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
@@ -12,9 +15,13 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
class ImageDataset(Dataset, CaptionMixin):
@@ -27,7 +34,7 @@ class ImageDataset(Dataset, CaptionMixin):
self.include_prompt = self.get_config('include_prompt', False)
self.default_prompt = self.get_config('default_prompt', '')
if self.include_prompt:
self.caption_type = self.get_config('caption_type', 'txt')
self.caption_type = self.get_config('caption_ext', 'txt')
else:
self.caption_type = None
# we always random crop if random scale is enabled
@@ -48,6 +55,8 @@ class ImageDataset(Dataset, CaptionMixin):
else:
bad_count += 1
self.file_list = new_file_list
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.path}"
@@ -85,7 +94,10 @@ class ImageDataset(Dataset, CaptionMixin):
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
scaler = scale_size / min_img_size
scale_width = int((img.width + 5) * scaler)
scale_height = int((img.height + 5) * scaler)
img = img.resize((scale_width, scale_height), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
@@ -100,21 +112,7 @@ class ImageDataset(Dataset, CaptionMixin):
return img
class Augments:
def __init__(self, **kwargs):
self.method_name = kwargs.get('method', None)
self.params = kwargs.get('params', {})
# convert kwargs enums for cv2
for key, value in self.params.items():
if isinstance(value, str):
# split the string
split_string = value.split('.')
if len(split_string) == 2 and split_string[0] == 'cv2':
if hasattr(cv2, split_string[1]):
self.params[key] = getattr(cv2, split_string[1].upper())
else:
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
class AugmentedImageDataset(ImageDataset):
@@ -265,6 +263,38 @@ class PairedImageDataset(Dataset):
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
img_path = img_path_or_tuple[1]
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
# always use # 2 (pos)
bucket_resolution = get_bucket_for_image_size(
width=img2.width,
height=img2.height,
resolution=self.size,
# divisibility=self.
)
# images will be same base dimension, but may be trimmed. We need to shrink and then central crop
if bucket_resolution['width'] > bucket_resolution['height']:
img1_scale_to_height = bucket_resolution["height"]
img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
img2_scale_to_height = bucket_resolution["height"]
img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
else:
img1_scale_to_width = bucket_resolution["width"]
img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
img2_scale_to_width = bucket_resolution["width"]
img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
img1_crop_height = bucket_resolution["height"]
img1_crop_width = bucket_resolution["width"]
img2_crop_height = bucket_resolution["height"]
img2_crop_width = bucket_resolution["width"]
# scale then center crop images
img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
# combine them side by side
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
img.paste(img1, (0, 0))
@@ -272,52 +302,45 @@ class PairedImageDataset(Dataset):
else:
img_path = img_path_or_tuple
img = exif_transpose(Image.open(img_path)).convert('RGB')
height = self.size
# determine width to keep aspect ratio
width = int(img.size[0] * height / img.size[1])
# Downscale the source image first
img = img.resize((width, height), Image.BICUBIC)
prompt = self.get_prompt_item(index)
height = self.size
# determine width to keep aspect ratio
width = int(img.size[0] * height / img.size[1])
# Downscale the source image first
img = img.resize((width, height), Image.BICUBIC)
img = self.transform(img)
return img, prompt, (self.neg_weight, self.pos_weight)
printed_messages = []
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class FileItem:
def __init__(self, **kwargs):
self.path = kwargs.get('path', None)
self.width = kwargs.get('width', None)
self.height = kwargs.get('height', None)
# we scale first, then crop
self.scale_to_width = kwargs.get('scale_to_width', self.width)
self.scale_to_height = kwargs.get('scale_to_height', self.height)
# crop values are from scaled size
self.crop_x = kwargs.get('crop_x', 0)
self.crop_y = kwargs.get('crop_y', 0)
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
def __init__(
self,
dataset_config: 'DatasetConfig',
batch_size=1,
sd: 'StableDiffusion' = None,
):
super().__init__()
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
folder_path = dataset_config.folder_path
self.dataset_path = dataset_config.dataset_path
if self.dataset_path is None:
self.dataset_path = folder_path
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.epoch_num = 0
self.sd = sd
if self.sd is None and self.is_caching_latents:
raise ValueError(f"sd is required for caching latents")
self.caption_type = dataset_config.caption_ext
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
@@ -325,176 +348,211 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
self.file_list: List['FileItem'] = []
self.caption_dict = None
self.file_list: List['FileItemDTO'] = []
# get the file list
file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# check if dataset_path is a folder or json
if os.path.isdir(self.dataset_path):
file_list = [
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
else:
# assume json
with open(self.dataset_path, 'r') as f:
self.caption_dict = json.load(f)
# keys are file paths
file_list = list(self.caption_dict.keys())
if self.dataset_config.num_repeats > 1:
# repeat the list
file_list = file_list * self.dataset_config.num_repeats
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0
for file in tqdm(file_list):
try:
w, h = image_utils.get_image_size(file)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
img = Image.open(file)
h, w = img.size
if int(min(h, w) * self.scale) >= self.resolution:
self.file_list.append(
FileItem(
path=file,
width=w,
height=h,
scale_to_width=int(w * self.scale),
scale_to_height=int(h * self.scale),
)
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
)
else:
self.file_list.append(file_item)
except Exception as e:
print(traceback.format_exc())
print(f"Error processing image: {file}")
print(e)
bad_count += 1
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
# print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
# handle x axis flips
if self.dataset_config.flip_x:
print(" - adding x axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the x axis
new_file_item = copy.deepcopy(file_item)
new_file_item.flip_x = True
self.file_list.append(new_file_item)
# handle y axis flips
if self.dataset_config.flip_y:
print(" - adding y axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the y axis
new_file_item = copy.deepcopy(file_item)
new_file_item.flip_y = True
self.file_list.append(new_file_item)
if self.dataset_config.flip_x or self.dataset_config.flip_y:
print(f" - Found {len(self.file_list)} images after adding flips")
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
self.setup_epoch()
def setup_epoch(self):
if self.epoch_num == 0:
# initial setup
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
# setup buckets every epoch
self.setup_buckets(quiet=True)
self.epoch_num += 1
def __len__(self):
if self.dataset_config.buckets:
return len(self.batch_indices)
return len(self.file_list)
def _get_single_item(self, index):
file_item = self.file_list[index]
# todo make sure this matches
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
w, h = img.size
if w > h and file_item.scale_to_width < file_item.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
else:
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
# todo convert it all
dataset_config_dict = {
"is_reg": 1 if self.dataset_config.is_reg else 0,
}
if self.caption_type is not None:
prompt = self.get_caption_item(index)
return img, prompt, dataset_config_dict
else:
return img, dataset_config_dict
def _get_single_item(self, index) -> 'FileItemDTO':
file_item = copy.deepcopy(self.file_list[index])
file_item.load_and_process_image(self.transform)
file_item.load_caption(self.caption_dict)
return file_item
def __getitem__(self, item):
if self.dataset_config.buckets:
# for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets
# we collate ourselves
if len(self.batch_indices) - 1 < item:
# tried everything to solve this. No way to reset length when redoing things. Pick another index
item = random.randint(0, len(self.batch_indices) - 1)
idx_list = self.batch_indices[item]
tensor_list = []
prompt_list = []
dataset_config_dict_list = []
for idx in idx_list:
if self.caption_type is not None:
img, prompt, dataset_config_dict = self._get_single_item(idx)
prompt_list.append(prompt)
dataset_config_dict_list.append(dataset_config_dict)
else:
img, dataset_config_dict = self._get_single_item(idx)
dataset_config_dict_list.append(dataset_config_dict)
tensor_list.append(img.unsqueeze(0))
if self.caption_type is not None:
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
else:
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
return [self._get_single_item(idx) for idx in idx_list]
else:
# Dataloader is batching
return self._get_single_item(item)
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
def get_dataloader_from_datasets(
dataset_options,
batch_size=1,
sd: 'StableDiffusion' = None,
) -> DataLoader:
if dataset_options is None or len(dataset_options) == 0:
return None
datasets = []
has_buckets = False
is_caching_latents = False
dataset_config_list = []
# preprocess them all
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
dataset_config_list.append(dataset_option)
else:
config = DatasetConfig(**dataset_option)
# preprocess raw data
split_configs = preprocess_dataset_raw_config([dataset_option])
for x in split_configs:
dataset_config_list.append(DatasetConfig(**x))
for config in dataset_config_list:
if config.type == 'image':
dataset = AiToolkitDataset(config, batch_size=batch_size)
dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
datasets.append(dataset)
if config.buckets:
has_buckets = True
if config.cache_latents or config.cache_latents_to_disk:
is_caching_latents = True
else:
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
# todo build scheduler that can get buckets from all datasets that match
# todo and evenly distribute reg images
def dto_collation(batch: List['FileItemDTO']):
# create DTO batch
batch = DataLoaderBatchDTO(
file_items=batch
)
return batch
# check if is caching latents
if has_buckets:
# make sure they all have buckets
for dataset in datasets:
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
def custom_collate_fn(batch):
# just return as is
return batch
data_loader = DataLoader(
concatenated_dataset,
batch_size=None, # we batch in the dataloader
batch_size=None, # we batch in the datasets for now
drop_last=False,
shuffle=True,
collate_fn=custom_collate_fn, # Use the custom collate function
num_workers=2
collate_fn=dto_collation, # Use the custom collate function
num_workers=4
)
else:
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
num_workers=4,
collate_fn=dto_collation
)
return data_loader
def trigger_dataloader_setup_epoch(dataloader: DataLoader):
# hacky but needed because of different types of datasets and dataloaders
dataloader.len = None
if isinstance(dataloader.dataset, list):
for dataset in dataloader.dataset:
if hasattr(dataset, 'datasets'):
for sub_dataset in dataset.datasets:
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None
elif hasattr(dataset, 'setup_epoch'):
dataset.setup_epoch()
dataset.len = None
elif hasattr(dataloader.dataset, 'setup_epoch'):
dataloader.dataset.setup_epoch()
dataloader.dataset.len = None
elif hasattr(dataloader.dataset, 'datasets'):
dataloader.dataset.len = None
for sub_dataset in dataloader.dataset.datasets:
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None

View File

@@ -0,0 +1,202 @@
from typing import TYPE_CHECKING, List, Union
import torch
import random
from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
printed_messages = []
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class FileItemDTO(
LatentCachingFileItemDTOMixin,
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
UnconditionalFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
def __init__(self, *args, **kwargs):
self.path = kwargs.get('path', None)
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
# process width and height
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
img = exif_transpose(Image.open(self.path))
h, w = img.size
self.width: int = w
self.height: int = h
super().__init__(*args, **kwargs)
# self.caption_path: str = kwargs.get('caption_path', None)
self.raw_caption: str = kwargs.get('raw_caption', None)
# we scale first, then crop
self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
# crop values are from scaled size
self.crop_x: int = kwargs.get('crop_x', 0)
self.crop_y: int = kwargs.get('crop_y', 0)
self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_x', False)
self.augments: List[str] = self.dataset_config.augments
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg
self.tensor: Union[torch.Tensor, None] = None
def cleanup(self):
self.tensor = None
self.cleanup_latent()
self.cleanup_control()
self.cleanup_mask()
self.cleanup_unconditional()
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
is_latents_cached = self.file_items[0].is_latent_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latents: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
if any([x.control_tensor is not None for x in self.file_items]):
# find one to use as a base
base_control_tensor = None
for x in self.file_items:
if x.control_tensor is not None:
base_control_tensor = x.control_tensor
break
control_tensors = []
for x in self.file_items:
if x.control_tensor is None:
control_tensors.append(torch.zeros_like(base_control_tensor))
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
if any([x.mask_tensor is not None for x in self.file_items]):
# find one to use as a base
base_mask_tensor = None
for x in self.file_items:
if x.mask_tensor is not None:
base_mask_tensor = x.mask_tensor
break
mask_tensors = []
for x in self.file_items:
if x.mask_tensor is None:
mask_tensors.append(torch.zeros_like(base_mask_tensor))
else:
mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
# add unaugmented tensors for ones with augments
if any([x.unaugmented_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unaugmented_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unaugmented_tensor = x.unaugmented_tensor
break
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
# add unconditional tensors
if any([x.unconditional_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unconditional_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unconditional_tensor = x.unconditional_tensor
break
unconditional_tensor = []
for x in self.file_items:
if x.unconditional_tensor is None:
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
else:
unconditional_tensor.append(x.unconditional_tensor)
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]
def get_network_weight_list(self):
return [x.network_weight for x in self.file_items]
def get_caption_list(
self,
trigger=None,
to_replace_list=None,
add_if_not_present=True
):
return [x.get_caption(
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present
) for x in self.file_items]
def get_caption_short_list(
self,
trigger=None,
to_replace_list=None,
add_if_not_present=True
):
return [x.get_caption(
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present,
short_caption=True
) for x in self.file_items]
def cleanup(self):
del self.latents
del self.tensor
del self.control_tensor
for file_item in self.file_items:
file_item.cleanup()

File diff suppressed because it is too large Load Diff

View File

@@ -21,7 +21,8 @@ class Embedding:
def __init__(
self,
sd: 'StableDiffusion',
embed_config: 'EmbeddingConfig'
embed_config: 'EmbeddingConfig',
state_dict: OrderedDict = None,
):
self.name = embed_config.trigger
self.sd = sd
@@ -38,74 +39,115 @@ class Embedding:
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
placeholder_tokens += additional_tokens
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# handle dual tokenizer
self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
self.sd.text_encoder]
# Convert the initializer_token, placeholder_token to ids
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if length of token ids is more than number of orm embedding tokens fill with *
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
self.placeholder_token_ids = []
self.embedding_tokens = []
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
# todo SDXL has 2 text encoders, need to do both for all of this
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.embed_config.tokens:
raise ValueError(
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
)
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# Convert the initializer_token, placeholder_token to ids
init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
# if length of token ids is more than number of orm embedding tokens fill with *
if len(init_token_ids) > self.embed_config.tokens:
init_token_ids = init_token_ids[:self.embed_config.tokens]
elif len(init_token_ids) < self.embed_config.tokens:
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
self.placeholder_token_ids.append(placeholder_token_ids)
# returns the string to have in the prompt to trigger the embedding
def get_embedding_string(self):
return self.embedding_tokens
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
# backup text encoder embeddings
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
def get_trainable_params(self):
# todo only get this one as we could have more than one
return self.sd.text_encoder.get_input_embeddings().parameters()
params = []
for text_encoder in self.text_encoder_list:
params += text_encoder.get_input_embeddings().parameters()
return params
# make setter and getter for vec
@property
def vec(self):
def _get_vec(self, text_encoder_idx=0):
# should we get params instead
# create vector from token embeds
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
# stack the tokens along batch axis adding that axis
new_vector = torch.stack(
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
[token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
dim=0
)
return new_vector
@vec.setter
def vec(self, new_vector):
def _set_vec(self, new_vector, text_encoder_idx=0):
# shape is (1, 768) for SD 1.5 for 1 token
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
for i in range(new_vector.shape[0]):
# apply the weights to the placeholder tokens while preserving gradient
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
x = 1
token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
# make setter and getter for vec
@property
def vec(self):
return self._get_vec(0)
@vec.setter
def vec(self, new_vector):
self._set_vec(new_vector, 0)
@property
def vec2(self):
return self._get_vec(1)
@vec2.setter
def vec2(self, new_vector):
self._set_vec(new_vector, 1)
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
# however, on training we don't use that pipeline, so we have to do it ourselves
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
default_replacements = ["[name]", "[trigger]"]
replace_with = self.embedding_tokens if expand_token else self.trigger
replace_with = embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
@@ -120,7 +162,7 @@ class Embedding:
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = prompt.count(replace_with)
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
@@ -128,10 +170,21 @@ class Embedding:
if num_instances > 1:
print(
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
def state_dict(self):
if self.sd.is_xl:
state_dict = OrderedDict()
state_dict['clip_l'] = self.vec
state_dict['clip_g'] = self.vec2
else:
state_dict = OrderedDict()
state_dict['emb_params'] = self.vec
return state_dict
def save(self, filename):
# todo check to see how to get the vector out of the embedding
@@ -145,13 +198,14 @@ class Embedding:
"sd_checkpoint_name": None,
"notes": None,
}
# TODO we do not currently support this. Check how auto is doing it. Only safetensors supported sor sdxl
if filename.endswith('.pt'):
torch.save(embedding_data, filename)
elif filename.endswith('.bin'):
torch.save(embedding_data, filename)
elif filename.endswith('.safetensors'):
# save the embedding as a safetensors file
state_dict = {"emb_params": self.vec}
state_dict = self.state_dict()
# add all embedding data (except string_to_param), to metadata
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
metadata["string_to_param"] = {"*": "emb_params"}
@@ -163,6 +217,7 @@ class Embedding:
path = os.path.realpath(file_path)
filename = os.path.basename(path)
name, ext = os.path.splitext(filename)
tensors = {}
ext = ext.upper()
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
_, second_ext = os.path.splitext(name)
@@ -170,10 +225,12 @@ class Embedding:
return
if ext in ['.BIN', '.PT']:
# todo check this
if self.sd.is_xl:
raise Exception("XL not supported yet for bin, pt")
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
# rebuild the embedding from the safetensors file if it has it
tensors = {}
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
metadata = f.metadata()
for k in f.keys():
@@ -195,26 +252,32 @@ class Embedding:
else:
return
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
if self.sd.is_xl:
self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
if 'step' in data:
self.step = int(data['step'])
else:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict,
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
if 'step' in data:
self.step = int(data['step'])
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
self.vec = emb.detach().to(device, dtype=torch.float32)
if 'step' in data:
self.step = int(data['step'])
self.vec = emb.detach().to(device, dtype=torch.float32)

View File

@@ -1,9 +1,16 @@
# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
import atexit
import collections
import json
import os
import io
import struct
from typing import TYPE_CHECKING
import cv2
import numpy as np
import torch
from diffusers import AutoencoderTiny
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
@@ -112,7 +119,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
width = int(w)
height = int(h)
elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
and (data[12:16] == b'IHDR')):
and (data[12:16] == b'IHDR')):
# PNGs
imgtype = PNG
w, h = struct.unpack(">LL", data[16:24])
@@ -190,7 +197,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
9: (4, boChar + "l"), # SLONG
10: (8, boChar + "ll"), # SRATIONAL
11: (4, boChar + "f"), # FLOAT
12: (8, boChar + "d") # DOUBLE
12: (8, boChar + "d") # DOUBLE
}
ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
try:
@@ -206,7 +213,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
input.seek(entryOffset)
tag = input.read(2)
tag = struct.unpack(boChar + "H", tag)[0]
if(tag == 256 or tag == 257):
if (tag == 256 or tag == 257):
# if type indicates that value fits into 4 bytes, value
# offset is not an offset but value itself
type = input.read(2)
@@ -229,7 +236,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
except Exception as e:
raise UnknownImageFormat(str(e))
elif size >= 2:
# see http://en.wikipedia.org/wiki/ICO_(file_format)
# see http://en.wikipedia.org/wiki/ICO_(file_format)
imgtype = 'ICO'
input.seek(0)
reserved = input.read(2)
@@ -350,13 +357,13 @@ def main(argv=None):
prs.add_option('-v', '--verbose',
dest='verbose',
action='store_true',)
action='store_true', )
prs.add_option('-q', '--quiet',
dest='quiet',
action='store_true',)
action='store_true', )
prs.add_option('-t', '--test',
dest='run_tests',
action='store_true',)
action='store_true', )
argv = list(argv) if argv is not None else sys.argv[1:]
(opts, args) = prs.parse_args(args=argv)
@@ -417,6 +424,60 @@ def main(argv=None):
return EX_OK
is_window_shown = False
def show_img(img, name='AI Toolkit'):
global is_window_shown
img = np.clip(img, 0, 255).astype(np.uint8)
cv2.imshow(name, img[:, :, ::-1])
k = cv2.waitKey(10) & 0xFF
if k == 27: # Esc key to stop
print('\nESC pressed, stopping')
raise KeyboardInterrupt
if not is_window_shown:
is_window_shown = True
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
# if rank is 4
if len(imgs.shape) == 4:
img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
else:
img_list = [imgs]
# put images side by side
img = torch.cat(img_list, dim=3)
# img is -1 to 1, convert to 0 to 255
img = img / 2 + 0.5
img_numpy = img.to(torch.float32).detach().cpu().numpy()
img_numpy = np.clip(img_numpy, 0, 1) * 255
# convert to numpy Move channel to last
img_numpy = img_numpy.transpose(0, 2, 3, 1)
# convert to uint8
img_numpy = img_numpy.astype(np.uint8)
show_img(img_numpy[0], name=name)
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
# decode latents
if vae.device == 'cpu':
vae.to(latents.device)
latents = latents / vae.config['scaling_factor']
imgs = vae.decode(latents).sample
show_tensors(imgs, name=name)
def on_exit():
if is_window_shown:
cv2.destroyAllWindows()
atexit.register(on_exit)
if __name__ == "__main__":
import sys
sys.exit(main(argv=sys.argv[1:]))
sys.exit(main(argv=sys.argv[1:]))

410
toolkit/inversion_utils.py Normal file
View File

@@ -0,0 +1,410 @@
# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
import torch
import os
from tqdm import tqdm
from toolkit import train_tools
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion
def mu_tilde(model, xt, x0, timestep):
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
alpha_t = model.scheduler.alphas[timestep]
beta_t = 1 - alpha_t
alpha_bar = model.scheduler.alphas_cumprod[timestep]
return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + (
(alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50):
"""
Samples from P(x_1:T|x_0)
"""
# torch.manual_seed(43256465436)
alpha_bar = sd.noise_scheduler.alphas_cumprod
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
alphas = sd.noise_scheduler.alphas
betas = 1 - alphas
# variance_noise_shape = (
# num_inference_steps,
# sd.unet.in_channels,
# sd.unet.sample_size,
# sd.unet.sample_size)
variance_noise_shape = list(sample.shape)
variance_noise_shape[0] = num_inference_steps
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16)
for t in reversed(timesteps):
idx = t_to_idx[int(t)]
xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t]
xts = torch.cat([xts, sample], dim=0)
return xts
def encode_text(model, prompts):
text_input = model.tokenizer(
prompts,
padding="max_length",
max_length=model.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
return text_encoding
def forward_step(sd: StableDiffusion, model_output, timestep, sample):
next_timestep = min(
sd.noise_scheduler.config['num_train_timesteps'] - 2,
timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
)
# 2. compute alphas, betas
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
# alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 5. TODO: simple noising implementation
next_sample = sd.noise_scheduler.add_noise(
pred_original_sample,
model_output,
torch.LongTensor([next_timestep]))
return next_sample
def get_variance(sd: StableDiffusion, timestep): # , prev_timestep):
prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
if sd.is_xl:
bs, ch, h, w = list(latents.shape)
height = h * VAE_SCALE_FACTOR
width = w * VAE_SCALE_FACTOR
dtype = latents.dtype
# just do it without any cropping nonsense
target_size = (height, width)
original_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
batch_time_ids = torch.cat(
[add_time_ids for _ in range(bs)]
)
return batch_time_ids
else:
return None
def inversion_forward_process(
sd: StableDiffusion,
sample: torch.Tensor,
conditional_embeddings: PromptEmbeds,
unconditional_embeddings: PromptEmbeds,
etas=None,
prog_bar=False,
cfg_scale=3.5,
num_inference_steps=50, eps=None
):
current_num_timesteps = len(sd.noise_scheduler.timesteps)
sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
# variance_noise_shape = (
# num_inference_steps,
# sd.unet.in_channels,
# sd.unet.sample_size,
# sd.unet.sample_size
# )
variance_noise_shape = list(sample.shape)
variance_noise_shape[0] = num_inference_steps
if etas is None or (type(etas) in [int, float] and etas == 0):
eta_is_zero = True
zs = None
else:
eta_is_zero = False
if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
alpha_bar = sd.noise_scheduler.alphas_cumprod
zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
noisy_sample = sample
op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
for timestep in op:
idx = t_to_idx[int(timestep)]
# 1. predict noise residual
if not eta_is_zero:
noisy_sample = xts[idx][None]
added_cond_kwargs = {}
with torch.no_grad():
text_embeddings = train_tools.concat_prompt_embeddings(
unconditional_embeddings, # negative embedding
conditional_embeddings, # positive embedding
1, # batch size
)
if sd.is_xl:
add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
# add extra for cfg
add_time_ids = torch.cat(
[add_time_ids] * 2, dim=0
)
added_cond_kwargs = {
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
# double up for cfg
latent_model_input = torch.cat(
[noisy_sample] * 2, dim=0
)
noise_pred = sd.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
# cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
if eta_is_zero:
# 2. compute more noisy image and set x_t -> x_t+1
noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
xts = None
else:
xtm1 = xts[idx + 1][None]
# pred of x0
pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
timestep] ** 0.5
# direction to xt
prev_timestep = timestep - sd.noise_scheduler.config[
'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
variance = get_variance(sd, timestep)
pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
zs[idx] = z
# correction to avoid error accumulation
xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
xts[idx + 1] = xtm1
if not zs is None:
zs[-1] = torch.zeros_like(zs[-1])
# restore timesteps
sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
return noisy_sample, zs, xts
#
# def inversion_forward_process(
# model,
# sample,
# etas=None,
# prog_bar=False,
# prompt="",
# cfg_scale=3.5,
# num_inference_steps=50, eps=None
# ):
# if not prompt == "":
# text_embeddings = encode_text(model, prompt)
# uncond_embedding = encode_text(model, "")
# timesteps = model.scheduler.timesteps.to(model.device)
# variance_noise_shape = (
# num_inference_steps,
# model.unet.in_channels,
# model.unet.sample_size,
# model.unet.sample_size)
# if etas is None or (type(etas) in [int, float] and etas == 0):
# eta_is_zero = True
# zs = None
# else:
# eta_is_zero = False
# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
# alpha_bar = model.scheduler.alphas_cumprod
# zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
#
# t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
# noisy_sample = sample
# op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
#
# for t in op:
# idx = t_to_idx[int(t)]
# # 1. predict noise residual
# if not eta_is_zero:
# noisy_sample = xts[idx][None]
#
# with torch.no_grad():
# out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
# if not prompt == "":
# cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
#
# if not prompt == "":
# ## classifier free guidance
# noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
# else:
# noise_pred = out.sample
#
# if eta_is_zero:
# # 2. compute more noisy image and set x_t -> x_t+1
# noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
#
# else:
# xtm1 = xts[idx + 1][None]
# # pred of x0
# pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
#
# # direction to xt
# prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
# alpha_prod_t_prev = model.scheduler.alphas_cumprod[
# prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
#
# variance = get_variance(model, t)
# pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
#
# mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
#
# z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
# zs[idx] = z
#
# # correction to avoid error accumulation
# xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
# xts[idx + 1] = xtm1
#
# if not zs is None:
# zs[-1] = torch.zeros_like(zs[-1])
#
# return noisy_sample, zs, xts
def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
# 1. get previous step value (=t-1)
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
# variance = self.scheduler._get_variance(timestep, prev_timestep)
variance = get_variance(model, timestep) # , prev_timestep)
std_dev_t = eta * variance ** (0.5)
# Take care of asymetric reverse process (asyrp)
model_output_direction = model_output
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
# 8. Add noice if eta > 0
if eta > 0:
if variance_noise is None:
variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
sigma_z = eta * variance ** (0.5) * variance_noise
prev_sample = prev_sample + sigma_z
return prev_sample
def inversion_reverse_process(
model,
xT,
etas=0,
prompts="",
cfg_scales=None,
prog_bar=False,
zs=None,
controller=None,
asyrp=False):
batch_size = len(prompts)
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
text_embeddings = encode_text(model, prompts)
uncond_embedding = encode_text(model, [""] * batch_size)
if etas is None: etas = 0
if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
assert len(etas) == model.scheduler.num_inference_steps
timesteps = model.scheduler.timesteps.to(model.device)
xt = xT.expand(batch_size, -1, -1, -1)
op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
for t in op:
idx = t_to_idx[int(t)]
## Unconditional embedding
with torch.no_grad():
uncond_out = model.unet.forward(xt, timestep=t,
encoder_hidden_states=uncond_embedding)
## Conditional embedding
if prompts:
with torch.no_grad():
cond_out = model.unet.forward(xt, timestep=t,
encoder_hidden_states=text_embeddings)
z = zs[idx] if not zs is None else None
z = z.expand(batch_size, -1, -1, -1)
if prompts:
## classifier free guidance
noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
else:
noise_pred = uncond_out.sample
# 2. compute less noisy image and set x_t -> x_t-1
xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
if controller is not None:
xt = controller.step_callback(xt)
return xt, zs

157
toolkit/ip_adapter.py Normal file
View File

@@ -0,0 +1,157 @@
import torch
import sys
from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
super().__init__()
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
if adapter_config.type == 'ip':
# ip-adapter
image_proj_model = ImageProjModel(
cross_attention_dim=sd.unet.config['cross_attention_dim'],
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=4,
)
elif adapter_config.type == 'ip+':
# ip-adapter-plus
num_tokens = 16
image_proj_model = Resampler(
dim=sd.unet.config['cross_attention_dim'],
depth=4,
dim_head=64,
heads=12,
num_queries=num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
)
else:
raise ValueError(f"unknown adapter type: {adapter_config.type}")
# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
for name in sd.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
else:
# they didnt have this, but would lead to undefined below
raise ValueError(f"unknown attn processor name: {name}")
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
attn_procs[name].load_state_dict(weights)
sd.unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
sd.adapter = self
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
return self
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
state_dict["image_proj"] = self.image_proj_model.state_dict()
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
return state_dict
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
@torch.no_grad()
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
# todo: add support for sdxl
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds
@torch.no_grad()
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds
# use drop for prompt dropout, or negatives
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.detach()
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
for attn_processor in self.adapter_modules:
yield from attn_processor.parameters(recurse)
yield from self.image_proj_model.parameters(recurse)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
{
"ldm": {
"conditioner.embedders.0.model.logit_scale": {
"shape": [],
"min": 4.60546875,
"max": 4.60546875
},
"conditioner.embedders.0.model.text_projection": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625
}
},
"diffusers": {
"te1_text_projection.weight": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,8 +4,6 @@
"cond_stage_model.model.ln_final.weight": "te_text_model.final_layer_norm.weight",
"cond_stage_model.model.positional_embedding": "te_text_model.embeddings.position_embedding.weight",
"cond_stage_model.model.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight",
"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": "te_text_model.encoder.layers.0.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": "te_text_model.encoder.layers.0.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.0.ln_1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias",
@@ -16,8 +14,6 @@
"cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": "te_text_model.encoder.layers.1.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight": "te_text_model.encoder.layers.1.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.1.ln_1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias",
@@ -28,8 +24,6 @@
"cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": "te_text_model.encoder.layers.10.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight": "te_text_model.encoder.layers.10.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.10.ln_1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias",
@@ -40,8 +34,6 @@
"cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": "te_text_model.encoder.layers.11.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight": "te_text_model.encoder.layers.11.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.11.ln_1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias",
@@ -52,8 +44,6 @@
"cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": "te_text_model.encoder.layers.12.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight": "te_text_model.encoder.layers.12.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias": "te_text_model.encoder.layers.12.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight": "te_text_model.encoder.layers.12.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.12.ln_1.bias": "te_text_model.encoder.layers.12.layer_norm1.bias",
@@ -64,8 +54,6 @@
"cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight": "te_text_model.encoder.layers.12.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias": "te_text_model.encoder.layers.12.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight": "te_text_model.encoder.layers.12.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": "te_text_model.encoder.layers.13.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight": "te_text_model.encoder.layers.13.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias": "te_text_model.encoder.layers.13.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight": "te_text_model.encoder.layers.13.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.13.ln_1.bias": "te_text_model.encoder.layers.13.layer_norm1.bias",
@@ -76,8 +64,6 @@
"cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight": "te_text_model.encoder.layers.13.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias": "te_text_model.encoder.layers.13.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight": "te_text_model.encoder.layers.13.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": "te_text_model.encoder.layers.14.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight": "te_text_model.encoder.layers.14.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias": "te_text_model.encoder.layers.14.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight": "te_text_model.encoder.layers.14.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.14.ln_1.bias": "te_text_model.encoder.layers.14.layer_norm1.bias",
@@ -88,8 +74,6 @@
"cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight": "te_text_model.encoder.layers.14.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias": "te_text_model.encoder.layers.14.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight": "te_text_model.encoder.layers.14.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": "te_text_model.encoder.layers.15.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight": "te_text_model.encoder.layers.15.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias": "te_text_model.encoder.layers.15.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight": "te_text_model.encoder.layers.15.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.15.ln_1.bias": "te_text_model.encoder.layers.15.layer_norm1.bias",
@@ -100,8 +84,6 @@
"cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight": "te_text_model.encoder.layers.15.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias": "te_text_model.encoder.layers.15.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight": "te_text_model.encoder.layers.15.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": "te_text_model.encoder.layers.16.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight": "te_text_model.encoder.layers.16.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias": "te_text_model.encoder.layers.16.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight": "te_text_model.encoder.layers.16.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.16.ln_1.bias": "te_text_model.encoder.layers.16.layer_norm1.bias",
@@ -112,8 +94,6 @@
"cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight": "te_text_model.encoder.layers.16.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias": "te_text_model.encoder.layers.16.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight": "te_text_model.encoder.layers.16.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": "te_text_model.encoder.layers.17.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight": "te_text_model.encoder.layers.17.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias": "te_text_model.encoder.layers.17.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight": "te_text_model.encoder.layers.17.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.17.ln_1.bias": "te_text_model.encoder.layers.17.layer_norm1.bias",
@@ -124,8 +104,6 @@
"cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight": "te_text_model.encoder.layers.17.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias": "te_text_model.encoder.layers.17.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight": "te_text_model.encoder.layers.17.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": "te_text_model.encoder.layers.18.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight": "te_text_model.encoder.layers.18.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias": "te_text_model.encoder.layers.18.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight": "te_text_model.encoder.layers.18.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.18.ln_1.bias": "te_text_model.encoder.layers.18.layer_norm1.bias",
@@ -136,8 +114,6 @@
"cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight": "te_text_model.encoder.layers.18.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias": "te_text_model.encoder.layers.18.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight": "te_text_model.encoder.layers.18.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": "te_text_model.encoder.layers.19.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight": "te_text_model.encoder.layers.19.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias": "te_text_model.encoder.layers.19.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight": "te_text_model.encoder.layers.19.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.19.ln_1.bias": "te_text_model.encoder.layers.19.layer_norm1.bias",
@@ -148,8 +124,6 @@
"cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight": "te_text_model.encoder.layers.19.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias": "te_text_model.encoder.layers.19.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight": "te_text_model.encoder.layers.19.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": "te_text_model.encoder.layers.2.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight": "te_text_model.encoder.layers.2.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.2.ln_1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias",
@@ -160,8 +134,6 @@
"cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": "te_text_model.encoder.layers.20.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight": "te_text_model.encoder.layers.20.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias": "te_text_model.encoder.layers.20.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight": "te_text_model.encoder.layers.20.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.20.ln_1.bias": "te_text_model.encoder.layers.20.layer_norm1.bias",
@@ -172,8 +144,6 @@
"cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight": "te_text_model.encoder.layers.20.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias": "te_text_model.encoder.layers.20.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight": "te_text_model.encoder.layers.20.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": "te_text_model.encoder.layers.21.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight": "te_text_model.encoder.layers.21.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias": "te_text_model.encoder.layers.21.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight": "te_text_model.encoder.layers.21.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.21.ln_1.bias": "te_text_model.encoder.layers.21.layer_norm1.bias",
@@ -184,8 +154,6 @@
"cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight": "te_text_model.encoder.layers.21.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias": "te_text_model.encoder.layers.21.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight": "te_text_model.encoder.layers.21.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": "te_text_model.encoder.layers.22.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight": "te_text_model.encoder.layers.22.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias": "te_text_model.encoder.layers.22.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight": "te_text_model.encoder.layers.22.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.22.ln_1.bias": "te_text_model.encoder.layers.22.layer_norm1.bias",
@@ -196,8 +164,6 @@
"cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight": "te_text_model.encoder.layers.22.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias": "te_text_model.encoder.layers.22.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight": "te_text_model.encoder.layers.22.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": "te_text_model.encoder.layers.3.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight": "te_text_model.encoder.layers.3.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.3.ln_1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias",
@@ -208,8 +174,6 @@
"cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": "te_text_model.encoder.layers.4.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight": "te_text_model.encoder.layers.4.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.4.ln_1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias",
@@ -220,8 +184,6 @@
"cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": "te_text_model.encoder.layers.5.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight": "te_text_model.encoder.layers.5.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.5.ln_1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias",
@@ -232,8 +194,6 @@
"cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": "te_text_model.encoder.layers.6.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight": "te_text_model.encoder.layers.6.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.6.ln_1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias",
@@ -244,8 +204,6 @@
"cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": "te_text_model.encoder.layers.7.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight": "te_text_model.encoder.layers.7.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.7.ln_1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias",
@@ -256,8 +214,6 @@
"cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": "te_text_model.encoder.layers.8.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight": "te_text_model.encoder.layers.8.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.8.ln_1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias",
@@ -268,8 +224,6 @@
"cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight",
"cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias",
"cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight",
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": "te_text_model.encoder.layers.9.self_attn.MERGED.bias",
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight": "te_text_model.encoder.layers.9.self_attn.MERGED.weight",
"cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias",
"cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight",
"cond_stage_model.model.transformer.resblocks.9.ln_1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias",
@@ -530,11 +484,11 @@
"first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
"model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
"model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
"model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
"model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
"model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
"model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
"model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
"model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
@@ -566,31 +520,31 @@
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_time_embedding.linear_2.weight",
"model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
"model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias",
"model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight",
"model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight",
"model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias",
"model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight",
"model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias",
"model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight",
"model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
"model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
"model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias",
"model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight",
"model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight",
"model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias",
"model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight",
"model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias",
"model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight",
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
"model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
"model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
"model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
"model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
"model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
@@ -624,11 +578,11 @@
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
"model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
"model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
"model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
"model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
"model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
"model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
"model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
@@ -662,11 +616,11 @@
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
"model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
"model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
"model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
"model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
"model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
@@ -700,11 +654,11 @@
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
"model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
"model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
"model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
"model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
"model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
"model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
"model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
"model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
"model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
@@ -738,11 +692,11 @@
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
"model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
"model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
"model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
"model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
"model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
"model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
"model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
@@ -776,11 +730,11 @@
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias",
"model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight",
"model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.conv1.bias",
"model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
"model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
"model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
"model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
"model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
"model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
"model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
"model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
"model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
"model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
@@ -812,11 +766,11 @@
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.conv1.bias",
"model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
"model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
"model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
"model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
"model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
"model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
"model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
"model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
"model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
"model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
@@ -826,11 +780,11 @@
"model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
"model.diffusion_model.out.2.bias": "unet_conv_out.bias",
"model.diffusion_model.out.2.weight": "unet_conv_out.weight",
"model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
"model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
"model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
"model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
"model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
"model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
"model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
@@ -838,11 +792,11 @@
"model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
"model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
"model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
"model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
"model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
"model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
"model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
"model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
"model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
"model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
@@ -850,11 +804,11 @@
"model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
"model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
"model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
"model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight",
"model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias",
"model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight",
"model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight",
"model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias",
"model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight",
@@ -888,11 +842,11 @@
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight",
"model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias",
"model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight",
"model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight",
"model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias",
"model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight",
@@ -926,11 +880,11 @@
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
"model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
"model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
"model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
"model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
"model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
"model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
@@ -940,11 +894,11 @@
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
"model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
"model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
"model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
"model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
"model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
"model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
"model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
"model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
@@ -978,11 +932,11 @@
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
"model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
"model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
"model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
"model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
"model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
@@ -1016,11 +970,11 @@
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
"model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
"model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
"model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
"model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
"model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
@@ -1056,11 +1010,11 @@
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
"model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
"model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
"model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
"model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
"model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
"model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
"model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
@@ -1094,11 +1048,11 @@
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
"model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
"model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
"model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
"model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
"model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
"model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
"model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
@@ -1132,11 +1086,11 @@
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
"model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
"model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
"model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
"model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
"model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
"model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
"model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
@@ -1172,11 +1126,11 @@
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight",
"model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias",
"model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight",
"model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight",
"model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias",
"model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight",
"model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
"model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
"model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight",
"model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias",
"model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight",
@@ -1213,7 +1167,7 @@
"model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
"model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
"model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
"model.diffusion_model.time_embed.2.weight": "unet_mid_block.resnets.1.time_emb_proj.weight"
"model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
},
"ldm_diffusers_shape_map": {
"first_stage_model.decoder.mid.attn_1.k.weight": [
@@ -1264,62 +1218,6 @@
512
]
],
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": [
[
128,
256,
1,
1
],
[
128,
256,
1,
1
]
],
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": [
[
256,
512,
1,
1
],
[
256,
512,
1,
1
]
],
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": [
[
256,
128,
1,
1
],
[
256,
128,
1,
1
]
],
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": [
[
512,
256,
1,
1
],
[
512,
256,
1,
1
]
],
"first_stage_model.encoder.mid.attn_1.k.weight": [
[
512,
@@ -1367,230 +1265,6 @@
512,
512
]
],
"first_stage_model.post_quant_conv.weight": [
[
4,
4,
1,
1
],
[
4,
4,
1,
1
]
],
"first_stage_model.quant_conv.weight": [
[
8,
8,
1,
1
],
[
8,
8,
1,
1
]
],
"model.diffusion_model.input_blocks.4.0.skip_connection.weight": [
[
640,
320,
1,
1
],
[
640,
320,
1,
1
]
],
"model.diffusion_model.input_blocks.7.0.skip_connection.weight": [
[
1280,
640,
1,
1
],
[
1280,
640,
1,
1
]
],
"model.diffusion_model.output_blocks.0.0.skip_connection.weight": [
[
1280,
2560,
1,
1
],
[
1280,
2560,
1,
1
]
],
"model.diffusion_model.output_blocks.1.0.skip_connection.weight": [
[
1280,
2560,
1,
1
],
[
1280,
2560,
1,
1
]
],
"model.diffusion_model.output_blocks.10.0.skip_connection.weight": [
[
320,
640,
1,
1
],
[
320,
640,
1,
1
]
],
"model.diffusion_model.output_blocks.11.0.skip_connection.weight": [
[
320,
640,
1,
1
],
[
320,
640,
1,
1
]
],
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": [
[
1280,
2560,
1,
1
],
[
1280,
2560,
1,
1
]
],
"model.diffusion_model.output_blocks.3.0.skip_connection.weight": [
[
1280,
2560,
1,
1
],
[
1280,
2560,
1,
1
]
],
"model.diffusion_model.output_blocks.4.0.skip_connection.weight": [
[
1280,
2560,
1,
1
],
[
1280,
2560,
1,
1
]
],
"model.diffusion_model.output_blocks.5.0.skip_connection.weight": [
[
1280,
1920,
1,
1
],
[
1280,
1920,
1,
1
]
],
"model.diffusion_model.output_blocks.6.0.skip_connection.weight": [
[
640,
1920,
1,
1
],
[
640,
1920,
1,
1
]
],
"model.diffusion_model.output_blocks.7.0.skip_connection.weight": [
[
640,
1280,
1,
1
],
[
640,
1280,
1,
1
]
],
"model.diffusion_model.output_blocks.8.0.skip_connection.weight": [
[
640,
960,
1,
1
],
[
640,
960,
1,
1
]
],
"model.diffusion_model.output_blocks.9.0.skip_connection.weight": [
[
320,
960,
1,
1
],
[
320,
960,
1,
1
]
]
},
"ldm_diffusers_operator_map": {
@@ -1606,8 +1280,7 @@
"te_text_model.encoder.layers.0.self_attn.q_proj.weight",
"te_text_model.encoder.layers.0.self_attn.k_proj.weight",
"te_text_model.encoder.layers.0.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.0.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": {
"cat": [
@@ -1621,8 +1294,7 @@
"te_text_model.encoder.layers.1.self_attn.q_proj.weight",
"te_text_model.encoder.layers.1.self_attn.k_proj.weight",
"te_text_model.encoder.layers.1.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.1.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": {
"cat": [
@@ -1636,8 +1308,7 @@
"te_text_model.encoder.layers.10.self_attn.q_proj.weight",
"te_text_model.encoder.layers.10.self_attn.k_proj.weight",
"te_text_model.encoder.layers.10.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.10.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": {
"cat": [
@@ -1651,8 +1322,7 @@
"te_text_model.encoder.layers.11.self_attn.q_proj.weight",
"te_text_model.encoder.layers.11.self_attn.k_proj.weight",
"te_text_model.encoder.layers.11.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.11.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": {
"cat": [
@@ -1666,8 +1336,7 @@
"te_text_model.encoder.layers.12.self_attn.q_proj.weight",
"te_text_model.encoder.layers.12.self_attn.k_proj.weight",
"te_text_model.encoder.layers.12.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.12.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": {
"cat": [
@@ -1681,8 +1350,7 @@
"te_text_model.encoder.layers.13.self_attn.q_proj.weight",
"te_text_model.encoder.layers.13.self_attn.k_proj.weight",
"te_text_model.encoder.layers.13.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.13.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": {
"cat": [
@@ -1696,8 +1364,7 @@
"te_text_model.encoder.layers.14.self_attn.q_proj.weight",
"te_text_model.encoder.layers.14.self_attn.k_proj.weight",
"te_text_model.encoder.layers.14.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.14.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": {
"cat": [
@@ -1711,8 +1378,7 @@
"te_text_model.encoder.layers.15.self_attn.q_proj.weight",
"te_text_model.encoder.layers.15.self_attn.k_proj.weight",
"te_text_model.encoder.layers.15.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.15.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": {
"cat": [
@@ -1726,8 +1392,7 @@
"te_text_model.encoder.layers.16.self_attn.q_proj.weight",
"te_text_model.encoder.layers.16.self_attn.k_proj.weight",
"te_text_model.encoder.layers.16.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.16.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": {
"cat": [
@@ -1741,8 +1406,7 @@
"te_text_model.encoder.layers.17.self_attn.q_proj.weight",
"te_text_model.encoder.layers.17.self_attn.k_proj.weight",
"te_text_model.encoder.layers.17.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.17.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": {
"cat": [
@@ -1756,8 +1420,7 @@
"te_text_model.encoder.layers.18.self_attn.q_proj.weight",
"te_text_model.encoder.layers.18.self_attn.k_proj.weight",
"te_text_model.encoder.layers.18.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.18.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": {
"cat": [
@@ -1771,8 +1434,7 @@
"te_text_model.encoder.layers.19.self_attn.q_proj.weight",
"te_text_model.encoder.layers.19.self_attn.k_proj.weight",
"te_text_model.encoder.layers.19.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.19.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": {
"cat": [
@@ -1786,8 +1448,7 @@
"te_text_model.encoder.layers.2.self_attn.q_proj.weight",
"te_text_model.encoder.layers.2.self_attn.k_proj.weight",
"te_text_model.encoder.layers.2.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.2.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": {
"cat": [
@@ -1801,8 +1462,7 @@
"te_text_model.encoder.layers.20.self_attn.q_proj.weight",
"te_text_model.encoder.layers.20.self_attn.k_proj.weight",
"te_text_model.encoder.layers.20.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.20.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": {
"cat": [
@@ -1816,8 +1476,7 @@
"te_text_model.encoder.layers.21.self_attn.q_proj.weight",
"te_text_model.encoder.layers.21.self_attn.k_proj.weight",
"te_text_model.encoder.layers.21.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.21.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": {
"cat": [
@@ -1831,8 +1490,7 @@
"te_text_model.encoder.layers.22.self_attn.q_proj.weight",
"te_text_model.encoder.layers.22.self_attn.k_proj.weight",
"te_text_model.encoder.layers.22.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.22.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": {
"cat": [
@@ -1846,8 +1504,7 @@
"te_text_model.encoder.layers.3.self_attn.q_proj.weight",
"te_text_model.encoder.layers.3.self_attn.k_proj.weight",
"te_text_model.encoder.layers.3.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.3.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": {
"cat": [
@@ -1861,8 +1518,7 @@
"te_text_model.encoder.layers.4.self_attn.q_proj.weight",
"te_text_model.encoder.layers.4.self_attn.k_proj.weight",
"te_text_model.encoder.layers.4.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.4.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": {
"cat": [
@@ -1876,8 +1532,7 @@
"te_text_model.encoder.layers.5.self_attn.q_proj.weight",
"te_text_model.encoder.layers.5.self_attn.k_proj.weight",
"te_text_model.encoder.layers.5.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.5.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": {
"cat": [
@@ -1891,8 +1546,7 @@
"te_text_model.encoder.layers.6.self_attn.q_proj.weight",
"te_text_model.encoder.layers.6.self_attn.k_proj.weight",
"te_text_model.encoder.layers.6.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.6.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": {
"cat": [
@@ -1906,8 +1560,7 @@
"te_text_model.encoder.layers.7.self_attn.q_proj.weight",
"te_text_model.encoder.layers.7.self_attn.k_proj.weight",
"te_text_model.encoder.layers.7.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.7.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": {
"cat": [
@@ -1921,8 +1574,7 @@
"te_text_model.encoder.layers.8.self_attn.q_proj.weight",
"te_text_model.encoder.layers.8.self_attn.k_proj.weight",
"te_text_model.encoder.layers.8.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.8.self_attn.MERGED.weight"
]
},
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": {
"cat": [
@@ -1936,8 +1588,7 @@
"te_text_model.encoder.layers.9.self_attn.q_proj.weight",
"te_text_model.encoder.layers.9.self_attn.k_proj.weight",
"te_text_model.encoder.layers.9.self_attn.v_proj.weight"
],
"target": "te_text_model.encoder.layers.9.self_attn.MERGED.weight"
]
}
},
"diffusers_ldm_operator_map": {

File diff suppressed because it is too large Load Diff

View File

@@ -6,16 +6,12 @@
77
],
"min": 0.0,
"max": 76.0,
"mean": 38.0,
"std": 22.375
"max": 76.0
},
"conditioner.embedders.1.model.logit_scale": {
"shape": [],
"min": 4.60546875,
"max": 4.60546875,
"mean": 4.60546875,
"std": NaN
"max": 4.60546875
},
"conditioner.embedders.1.model.text_projection": {
"shape": [
@@ -23,9 +19,7 @@
1280
],
"min": -0.15966796875,
"max": 0.230712890625,
"mean": 0.0,
"std": 0.0181732177734375
"max": 0.230712890625
}
},
"diffusers": {
@@ -35,9 +29,7 @@
1280
],
"min": -0.15966796875,
"max": 0.230712890625,
"mean": 2.128152846125886e-05,
"std": 0.018169498071074486
"max": 0.230712890625
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
{
"ldm": {
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
"shape": [
1,
77
],
"min": 0.0,
"max": 76.0
},
"conditioner.embedders.1.model.text_model.embeddings.position_ids": {
"shape": [
1,
77
],
"min": 0.0,
"max": 76.0
}
},
"diffusers": {}
}

View File

@@ -5,14 +5,18 @@ import itertools
class LosslessLatentDecoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentDecoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels // (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
if trainable:
self.kernel = nn.Parameter(numpy_kernel)
else:
self.kernel = numpy_kernel
def build_kernel(self, in_channels, latent_depth):
# my old code from tensorflow.
@@ -35,19 +39,27 @@ class LosslessLatentDecoder(nn.Module):
return kernel
def forward(self, x):
dtype = x.dtype
if self.kernel.dtype != dtype:
self.kernel = self.kernel.to(dtype=dtype)
# Deconvolve input tensor with the kernel
return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
class LosslessLatentEncoder(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentEncoder, self).__init__()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.latent_depth = latent_depth
self.in_channels = in_channels
self.out_channels = int(in_channels * (latent_depth * latent_depth))
numpy_kernel = self.build_kernel(in_channels, latent_depth)
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
if trainable:
self.kernel = nn.Parameter(numpy_kernel)
else:
self.kernel = numpy_kernel
def build_kernel(self, in_channels, latent_depth):
@@ -70,18 +82,21 @@ class LosslessLatentEncoder(nn.Module):
return kernel
def forward(self, x):
dtype = x.dtype
if self.kernel.dtype != dtype:
self.kernel = self.kernel.to(dtype=dtype)
# Convolve input tensor with the kernel
return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
class LosslessLatentVAE(nn.Module):
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
super(LosslessLatentVAE, self).__init__()
self.latent_depth = latent_depth
self.in_channels = in_channels
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
encoder_out_channels = self.encoder.out_channels
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
def forward(self, x):
latent = self.latent_encoder(x)
@@ -101,7 +116,7 @@ if __name__ == '__main__':
from PIL import Image
import torchvision.transforms as transforms
user_path = os.path.expanduser('~')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")

View File

@@ -1,243 +0,0 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
# - https://github.com/p1atdev/LECO/blob/main/lora.py
import os
import math
from typing import Optional, List, Type, Set, Literal
from collections import OrderedDict
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
from toolkit.metadata import add_model_hash_to_meta
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
"Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
]
UNET_TARGET_REPLACE_MODULE_CONV = [
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
] # locon, 3clier
LORA_PREFIX_UNET = "lora_unet"
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
TRAINING_METHODS = Literal[
"noxattn", # train all layers except x-attns and time_embed layers
"innoxattn", # train all layers except self attention layers
"selfattn", # ESD-u, train only self attention layers
"xattn", # ESD-x, train only x attention layers
"full", # train all layers
# "notime",
# "xlayer",
# "outxattn",
# "outsattn",
# "inxattn",
# "inmidsattn",
# "selflayer",
]
class LoRAModule(nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Linear":
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
elif org_module.__class__.__name__ == "Conv2d": # 一応
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
if self.lora_dim != lora_dim:
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = nn.Conv2d(
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)
class LoRANetwork(nn.Module):
def __init__(
self,
unet: UNet2DConditionModel,
rank: int = 4,
multiplier: float = 1.0,
alpha: float = 1.0,
train_method: TRAINING_METHODS = "full",
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = rank
self.alpha = alpha
# LoRAのみ
self.module = LoRAModule
# unetのloraを作る
self.unet_loras = self.create_modules(
LORA_PREFIX_UNET,
unet,
DEFAULT_TARGET_REPLACE,
self.lora_dim,
self.multiplier,
train_method=train_method,
)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
# assertion 名前の被りがないか確認しているようだ
lora_names = set()
for lora in self.unet_loras:
assert (
lora.lora_name not in lora_names
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
lora_names.add(lora.lora_name)
# 適用する
for lora in self.unet_loras:
lora.apply_to()
self.add_module(
lora.lora_name,
lora,
)
del unet
torch.cuda.empty_cache()
def create_modules(
self,
prefix: str,
root_module: nn.Module,
target_replace_modules: List[str],
rank: int,
multiplier: float,
train_method: TRAINING_METHODS,
) -> list:
loras = []
for name, module in root_module.named_modules():
if train_method == "noxattn": # Cross Attention と Time Embed 以外学習
if "attn2" in name or "time_embed" in name:
continue
elif train_method == "innoxattn": # Cross Attention 以外学習
if "attn2" in name:
continue
elif train_method == "selfattn": # Self Attention のみ学習
if "attn1" not in name:
continue
elif train_method == "xattn": # Cross Attention のみ学習
if "attn2" not in name:
continue
elif train_method == "full": # 全部学習
pass
else:
raise NotImplementedError(
f"train_method: {train_method} is not implemented."
)
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ in ["Linear", "Conv2d"]:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
print(f"{lora_name}")
lora = self.module(
lora_name, child_module, multiplier, rank, self.alpha
)
loras.append(lora)
return loras
def prepare_optimizer_params(self):
all_params = []
if self.unet_loras: # 実質これしかない
params = []
[params.extend(lora.parameters()) for lora in self.unet_loras]
param_data = {"params": params}
all_params.append(param_data)
return all_params
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
state_dict = self.state_dict()
if metadata is None:
metadata = OrderedDict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
for key in list(state_dict.keys()):
if not key.startswith("lora"):
# remove any not lora
del state_dict[key]
metadata = add_model_hash_to_meta(state_dict, metadata)
if os.path.splitext(file)[1] == ".safetensors":
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def __enter__(self):
for lora in self.unet_loras:
lora.multiplier = 1.0
def __exit__(self, exc_type, exc_value, tb):
for lora in self.unet_loras:
lora.multiplier = 0

View File

@@ -1,14 +1,16 @@
import json
import math
import os
import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .config_modules import NetworkConfig
from .lorm import count_parameters
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
from .paths import SD_SCRIPTS_ROOT
from .train_tools import get_torch_dtype
sys.path.append(SD_SCRIPTS_ROOT)
@@ -19,7 +21,18 @@ from torch.utils.checkpoint import checkpoint
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class LoRAModule(torch.nn.Module):
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
@@ -34,12 +47,20 @@ class LoRAModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
network: 'LoRASpecialNetwork' = None,
use_bias: bool = False,
**kwargs
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
ToolkitModuleMixin.__init__(self, network=network)
torch.nn.Module.__init__(self)
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ in CONV_MODULES:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
@@ -53,15 +74,15 @@ class LoRAModule(torch.nn.Module):
# else:
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ in CONV_MODULES:
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
@@ -74,157 +95,20 @@ class LoRAModule(torch.nn.Module):
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier: Union[float, List[float]] = multiplier
self.org_module = org_module # remove in applying
# wrap the original module so it doesn't get weights updated
self.org_module = [org_module]
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
# really only useful for slider training for now
def get_multiplier(self, lora_up):
with torch.no_grad():
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if there is more than our multiplier, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 0:
# single item, just return it
return self.multiplier[0]
elif len(self.multiplier) == batch_size:
# not doing CFG
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 for if total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor.detach()
else:
return self.multiplier
def _call_forward(self, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return lx * scale
def forward(self, x):
org_forwarded = self.org_forward(x)
lora_output = self._call_forward(x)
if self.is_normalizing:
with torch.no_grad():
# do this calculation without multiplier
# get a dim array from orig forward that had index of all dimensions except the batch and channel
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output *= normalize_scaler
multiplier = self.get_multiplier(lora_output)
return org_forwarded + (lora_output * multiplier)
def enable_gradient_checkpointing(self):
self.is_checkpointing = True
def disable_gradient_checkpointing(self):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
"""
Applied the previous normalization calculation to the module.
This must be called before saving or normalization will be lost.
It is probably best to call after each batch as well.
We just scale the up down weights to match this vector
:return:
"""
# get state dict
state_dict = self.state_dict()
dtype = state_dict['lora_up.weight'].dtype
device = state_dict['lora_up.weight'].device
# todo should we do this at fp32?
total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
.to(device, dtype=dtype)
num_modules_layers = 2 # up and down
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
.to(device, dtype=dtype)
# apply the scaler to the up and down weights
for key in state_dict.keys():
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
# do it inplace do params are updated
state_dict[key] *= up_down_scale
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
# del self.org_module
class LoRASpecialNetwork(LoRANetwork):
class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
@@ -258,7 +142,18 @@ class LoRASpecialNetwork(LoRANetwork):
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
train_text_encoder: Optional[bool] = True,
use_text_encoder_1: bool = True,
use_text_encoder_2: bool = True,
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
parameter_threshold: float = 0.0,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
**kwargs
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -269,8 +164,19 @@ class LoRASpecialNetwork(LoRANetwork):
5. modules_dimとmodules_alphaを指定 (推論用)
"""
# call the parent of the parent we are replacing (LoRANetwork) init
super(LoRANetwork, self).__init__()
torch.nn.Module.__init__(self)
ToolkitNetworkMixin.__init__(
self,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
is_sdxl=is_sdxl,
is_v2=is_v2,
is_lorm=is_lorm,
**kwargs
)
if ignore_if_contains is None:
ignore_if_contains = []
self.ignore_if_contains = ignore_if_contains
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
@@ -281,9 +187,11 @@ class LoRASpecialNetwork(LoRANetwork):
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.torch_multiplier = None
# triggers the state updates
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -325,11 +233,19 @@ class LoRASpecialNetwork(LoRANetwork):
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_linear = child_module.__class__.__name__ in LINEAR_MODULES
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
skip = False
if any([word in child_name for word in self.ignore_if_contains]):
skip = True
# see if it is over threshold
if count_parameters(child_module) < parameter_threshold:
skip = True
if (is_linear or is_conv2d) and not skip:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
@@ -375,6 +291,9 @@ class LoRASpecialNetwork(LoRANetwork):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
network=self,
parent=module,
use_bias=use_bias,
)
loras.append(lora)
return loras, skipped
@@ -387,6 +306,10 @@ class LoRASpecialNetwork(LoRANetwork):
skipped_te = []
if train_text_encoder:
for i, text_encoder in enumerate(text_encoders):
if not use_text_encoder_1 and i == 0:
continue
if not use_text_encoder_2 and i == 1:
continue
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
@@ -401,9 +324,9 @@ class LoRASpecialNetwork(LoRANetwork):
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
target_modules = target_lin_modules
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
target_modules += target_conv_modules
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
@@ -430,106 +353,3 @@ class LoRASpecialNetwork(LoRANetwork):
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
self._multiplier = value
self._update_lora_multiplier()
def _update_lora_multiplier(self):
if self.is_active:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = self._multiplier
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = self._multiplier
else:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = 0
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = 0
# called when the context manager is entered
# ie: with network:
def __enter__(self):
self.is_active = True
self._update_lora_multiplier()
def __exit__(self, exc_type, exc_value, tb):
self.is_active = False
self._update_lora_multiplier()
def force_to(self, device, dtype):
self.to(device, dtype)
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self):
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
return loras
def _update_checkpointing(self):
for module in self.get_all_modules():
if self.is_checkpointing:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
def enable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = True
self._update_checkpointing()
def disable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = False
self._update_checkpointing()
@property
def is_normalizing(self) -> bool:
return self._is_normalizing
@is_normalizing.setter
def is_normalizing(self, value: bool):
self._is_normalizing = value
for module in self.get_all_modules():
module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)

460
toolkit/lorm.py Normal file
View File

@@ -0,0 +1,460 @@
from typing import Union, Tuple, Literal, Optional
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm
from toolkit.config_modules import LoRMConfig
conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]
ExtractMode = Union[
'fixed',
'threshold',
'ratio',
'quantile',
'percentage'
]
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
]
CONV_MODULES = [
# 'Conv2d',
# 'LoRACompatibleConv'
]
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
# "ResnetBlock2D",
"Downsample2D",
"Upsample2D",
]
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
UNET_MODULES_TO_AVOID = [
]
# Low Rank Convolution
class LoRMCon2d(nn.Module):
def __init__(
self,
in_channels: int,
lorm_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 'same',
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super().__init__()
self.in_channels = in_channels
self.lorm_channels = lorm_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.padding_mode = padding_mode
self.down = nn.Conv2d(
in_channels=in_channels,
out_channels=lorm_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=False,
padding_mode=padding_mode,
device=device,
dtype=dtype
)
# Kernel size on the up is always 1x1.
# I don't think you could calculate a dual 3x3, or I can't at least
self.up = nn.Conv2d(
in_channels=lorm_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=1,
padding='same',
dilation=1,
groups=1,
bias=bias,
padding_mode='zeros',
device=device,
dtype=dtype
)
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
x = input
x = self.down(x)
x = self.up(x)
return x
class LoRMLinear(nn.Module):
def __init__(
self,
in_features: int,
lorm_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None
) -> None:
super().__init__()
self.in_features = in_features
self.lorm_features = lorm_features
self.out_features = out_features
self.down = nn.Linear(
in_features=in_features,
out_features=lorm_features,
bias=False,
device=device,
dtype=dtype
)
self.up = nn.Linear(
in_features=lorm_features,
out_features=out_features,
bias=bias,
# bias=True,
device=device,
dtype=dtype
)
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
x = input
x = self.down(x)
x = self.up(x)
return x
def extract_conv(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
weight = weight.to(device)
out_ch, in_ch, kernel_size, _ = weight.shape
U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
if mode == 'percentage':
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
original_params = out_ch * in_ch * kernel_size * kernel_size
desired_params = mode_param * original_params
# Solve for lora_rank from the equation
lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
elif mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param).item()
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s).item()
elif mode == 'quantile' or mode == 'percentile':
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum).item()
else:
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2:
lora_rank = int(out_ch / 2)
print(f"rank is higher than it should be")
# print(f"Skipping layer as determined rank is too high")
# return None, None, None, None
# return weight, 'full'
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
del U, S, Vh, weight
return extract_weight_A, extract_weight_B, lora_rank, diff
def extract_linear(
weight: Union[torch.Tensor, nn.Parameter],
mode='fixed',
mode_param=0,
device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
weight = weight.to(device)
out_ch, in_ch = weight.shape
U, S, Vh = torch.linalg.svd(weight)
if mode == 'percentage':
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
desired_params = mode_param * out_ch * in_ch
# Solve for lora_rank from the equation
lora_rank = int(desired_params / (in_ch + out_ch))
elif mode == 'fixed':
lora_rank = mode_param
elif mode == 'threshold':
assert mode_param >= 0
lora_rank = torch.sum(S > mode_param).item()
elif mode == 'ratio':
assert 1 >= mode_param >= 0
min_s = torch.max(S) * mode_param
lora_rank = torch.sum(S > min_s).item()
elif mode == 'quantile':
assert 1 >= mode_param >= 0
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = mode_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum).item()
else:
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
lora_rank = max(1, lora_rank)
lora_rank = min(out_ch, in_ch, lora_rank)
if lora_rank >= out_ch / 2:
# print(f"rank is higher than it should be")
lora_rank = int(out_ch / 2)
# return weight, 'full'
# print(f"Skipping layer as determined rank is too high")
# return None, None, None, None
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
diff = (weight - U @ Vh).detach()
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
del U, S, Vh, weight
return extract_weight_A, extract_weight_B, lora_rank, diff
def replace_module_by_path(network, name, module):
"""Replace a module in a network by its name."""
name_parts = name.split('.')
current_module = network
for part in name_parts[:-1]:
current_module = getattr(current_module, part)
try:
setattr(current_module, name_parts[-1], module)
except Exception as e:
print(e)
def count_parameters(module):
return sum(p.numel() for p in module.parameters())
def compute_optimal_bias(original_module, linear_down, linear_up, X):
Y_original = original_module(X)
Y_approx = linear_up(linear_down(X))
E = Y_original - Y_approx
optimal_bias = E.mean(dim=0)
return optimal_bias
def format_with_commas(n):
return f"{n:,}"
def print_lorm_extract_details(
start_num_params: int,
end_num_params: int,
num_replaced: int,
):
start_formatted = format_with_commas(start_num_params)
end_formatted = format_with_commas(end_num_params)
num_replaced_formatted = format_with_commas(num_replaced)
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
print(f"Convert UNet result:")
print(f" - converted: {num_replaced:>{width},} modules")
print(f" - start: {start_num_params:>{width},} params")
print(f" - end: {end_num_params:>{width},} params")
lorm_ignore_if_contains = [
'proj_out', 'proj_in',
]
lorm_parameter_threshold = 1000000
@torch.no_grad()
def convert_diffusers_unet_to_lorm(
unet: UNet2DConditionModel,
config: LoRMConfig,
):
print('Converting UNet to LoRM UNet')
start_num_params = count_parameters(unet)
named_modules = list(unet.named_modules())
num_replaced = 0
pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
layer_names_replaced = []
converted_modules = []
ignore_if_contains = [
'proj_out', 'proj_in',
]
for name, module in named_modules:
module_name = module.__class__.__name__
if module_name in UNET_TARGET_REPLACE_MODULE:
for child_name, child_module in module.named_modules():
new_module: Union[LoRMCon2d, LoRMLinear, None] = None
# if child name includes attn, skip it
combined_name = combined_name = f"{name}.{child_name}"
# if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
# pass
lorm_config = config.get_config_for_module(combined_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
if any([word in child_name for word in ignore_if_contains]):
pass
elif child_module.__class__.__name__ in LINEAR_MODULES:
if count_parameters(child_module) > parameter_threshold:
dtype = child_module.weight.dtype
# extract and convert
down_weight, up_weight, lora_dim, diff = extract_linear(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None
if child_module.bias is not None:
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
# linear layer weights = (out_features, in_features)
new_module = LoRMLinear(
in_features=down_weight.shape[1],
lorm_features=lora_dim,
out_features=up_weight.shape[0],
bias=bias_weight is not None,
device=down_weight.device,
dtype=down_weight.dtype
)
# replace the weights
new_module.down.weight.data = down_weight
new_module.up.weight.data = up_weight
if bias_weight is not None:
new_module.up.bias.data = bias_weight
# else:
# new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)
# bias_correction = compute_optimal_bias(
# child_module,
# new_module.down,
# new_module.up,
# torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
# )
# new_module.up.bias.data += bias_correction
elif child_module.__class__.__name__ in CONV_MODULES:
if count_parameters(child_module) > parameter_threshold:
dtype = child_module.weight.dtype
down_weight, up_weight, lora_dim, diff = extract_conv(
weight=child_module.weight.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None
if child_module.bias is not None:
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
new_module = LoRMCon2d(
in_channels=down_weight.shape[1],
lorm_channels=lora_dim,
out_channels=up_weight.shape[0],
kernel_size=child_module.kernel_size,
dilation=child_module.dilation,
padding=child_module.padding,
padding_mode=child_module.padding_mode,
stride=child_module.stride,
bias=bias_weight is not None,
device=down_weight.device,
dtype=down_weight.dtype
)
# replace the weights
new_module.down.weight.data = down_weight
new_module.up.weight.data = up_weight
if bias_weight is not None:
new_module.up.bias.data = bias_weight
if new_module:
combined_name = f"{name}.{child_name}"
replace_module_by_path(unet, combined_name, new_module)
converted_modules.append(new_module)
num_replaced += 1
layer_names_replaced.append(
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
pbar.update(1)
pbar.close()
end_num_params = count_parameters(unet)
def sorting_key(s):
# Extract the number part, remove commas, and convert to integer
return int(s.split("-")[1].strip().replace(",", ""))
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
for layer_name in sorted_layer_names_replaced:
print(layer_name)
print_lorm_extract_details(
start_num_params=start_num_params,
end_num_params=end_num_params,
num_replaced=num_replaced,
)
return converted_modules

View File

@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
with torch.autocast(device_type='cuda'):
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
real = real.float()
fake = fake.float()
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
if torch.isnan(interpolates).any():
print('d_interpolates is nan')
d_interpolates = critic(interpolates)
fake = torch.ones(real.size(0), 1, device=device)
if torch.isnan(d_interpolates).any():
print('fake is nan')
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
only_inputs=True,
)[0]
# see if any are nan
if torch.isnan(gradients).any():
print('gradients is nan')
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty
return gradient_penalty.float()
class PatternLoss(torch.nn.Module):

373
toolkit/lycoris_special.py Normal file
View File

@@ -0,0 +1,373 @@
import math
import os
from typing import Optional, Union, List, Type
import torch
from lycoris.kohya import LycorisNetwork, LoConModule
from lycoris.modules.glora import GLoRAModule
from torch import nn
from transformers import CLIPTextModel
from torch.nn import functional as F
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
def __init__(
self,
lora_name, org_module: nn.Module,
multiplier=1.0,
lora_dim=4, alpha=1,
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
network: 'LycorisSpecialNetwork' = None,
use_bias=False,
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
# call super of super
ToolkitModuleMixin.__init__(self, network=network)
torch.nn.Module.__init__(self)
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
self.scalar = nn.Parameter(torch.tensor(0.0))
orig_module_name = org_module.__class__.__name__
if orig_module_name in CONV_MODULES:
self.isconv = True
# For general LoCon
in_dim = org_module.in_channels
k_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
out_dim = org_module.out_channels
self.down_op = F.conv2d
self.up_op = F.conv2d
if use_cp and k_size != (1, 1):
self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
self.cp = True
else:
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
elif orig_module_name in LINEAR_MODULES:
self.isconv = False
self.down_op = F.linear
self.up_op = F.linear
if orig_module_name == 'GroupNorm':
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
in_dim = org_module.num_channels
out_dim = org_module.num_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
else:
raise NotImplementedError
self.shape = org_module.weight.shape
if dropout:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = nn.Identity()
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lora_up.weight)
if self.cp:
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
self.multiplier = multiplier
self.org_module = [org_module]
self.register_load_state_dict_post_hook(self.load_weight_hook)
def load_weight_hook(self, *args, **kwargs):
self.scalar = nn.Parameter(torch.ones_like(self.scalar))
class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
# 'UNet2DConditionModel',
# 'Conv2d',
# 'Timesteps',
# 'TimestepEmbedding',
# 'Linear',
# 'SiLU',
# 'ModuleList',
# 'DownBlock2D',
# 'ResnetBlock2D', # need
# 'GroupNorm',
# 'LoRACompatibleConv',
# 'LoRACompatibleLinear',
# 'Dropout',
# 'CrossAttnDownBlock2D', # needed
# 'Transformer2DModel', # maybe not, has duplicates
# 'BasicTransformerBlock', # duplicates
# 'LayerNorm',
# 'Attention',
# 'FeedForward',
# 'GEGLU',
# 'UpBlock2D',
# 'UNetMidBlock2DCrossAttn'
]
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
use_cp: Optional[bool] = False,
network_module: Type[object] = LoConSpecialModule,
train_unet: bool = True,
train_text_encoder: bool = True,
use_text_encoder_1: bool = True,
use_text_encoder_2: bool = True,
use_bias: bool = False,
is_lorm: bool = False,
**kwargs,
) -> None:
# call ToolkitNetworkMixin super
ToolkitNetworkMixin.__init__(
self,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
is_lorm=is_lorm,
**kwargs
)
# call the parent of the parent LycorisNetwork
torch.nn.Module.__init__(self)
# LyCORIS unique stuff
if dropout is None:
dropout = 0
if rank_dropout is None:
rank_dropout = 0
if module_dropout is None:
module_dropout = 0
self.train_unet = train_unet
self.train_text_encoder = train_text_encoder
self.torch_multiplier = None
# triggers a tensor update
self.multiplier = multiplier
self.lora_dim = lora_dim
if not self.ENABLE_CONV or conv_lora_dim is None:
conv_lora_dim = 0
conv_alpha = 0
self.conv_lora_dim = int(conv_lora_dim)
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
print('Apply different lora dim for conv layer')
print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
elif self.conv_lora_dim == 0:
print('Disable conv layer')
self.alpha = alpha
self.conv_alpha = float(conv_alpha)
if self.conv_lora_dim and self.alpha != self.conv_alpha:
print('Apply different alpha value for conv layer')
print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
if 1 >= dropout >= 0:
print(f'Use Dropout value: {dropout}')
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
# create module instances
def create_modules(
prefix,
root_module: torch.nn.Module,
target_replace_modules,
target_replace_names=[]
) -> List[network_module]:
print('Create LyCORIS Module')
loras = []
# remove this
named_modules = root_module.named_modules()
# add a few to tthe generator
for name, module in named_modules:
module_name = module.__class__.__name__
if module_name in target_replace_modules:
if module_name in self.MODULE_ALGO_MAP:
algo = self.MODULE_ALGO_MAP[module_name]
else:
algo = network_module
for child_name, child_module in module.named_modules():
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'):
print(f"{lora_name}")
if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif child_module.__class__.__name__ in CONV_MODULES:
k_size, *_ = child_module.kernel_size
if k_size == 1 and lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif conv_lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
else:
continue
else:
continue
loras.append(lora)
elif name in target_replace_names:
if name in self.NAME_ALGO_MAP:
algo = self.NAME_ALGO_MAP[name]
else:
algo = network_module
lora_name = prefix + '.' + name
lora_name = lora_name.replace('.', '_')
if module.__class__.__name__ == 'Linear' and lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
network=self,
use_bias=use_bias,
**kwargs
)
elif module.__class__.__name__ == 'Conv2d':
k_size, *_ = module.kernel_size
if k_size == 1 and lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
elif conv_lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
use_bias=use_bias,
**kwargs
)
else:
continue
else:
continue
loras.append(lora)
return loras
if network_module == GLoRAModule:
print('GLoRA enabled, only train transformer')
# only train transformer (for GLoRA)
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"Attention",
]
LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []
if isinstance(text_encoder, list):
text_encoders = text_encoder
use_index = True
else:
text_encoders = [text_encoder]
use_index = False
self.text_encoder_loras = []
if self.train_text_encoder:
for i, te in enumerate(text_encoders):
if not use_text_encoder_1 and i == 0:
continue
if not use_text_encoder_2 and i == 1:
continue
self.text_encoder_loras.extend(create_modules(
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
te,
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
))
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
if self.train_unet:
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
else:
self.unet_loras = []
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

View File

@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
with safe_open(file_path, framework="pt") as f:
metadata = f.metadata()
return parse_metadata_from_safetensors(metadata)
try:
with safe_open(file_path, framework="pt") as f:
metadata = f.metadata()
return parse_metadata_from_safetensors(metadata)
except Exception as e:
print(f"Error loading metadata from {file_path}: {e}")
return OrderedDict()

566
toolkit/network_mixins.py Normal file
View File

@@ -0,0 +1,566 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
import torch
from torch import nn
import weakref
from tqdm import tqdm
from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
ExtractMode = Union[
'existing'
'fixed',
'threshold',
'ratio',
'quantile',
'percentage'
]
def broadcast_and_multiply(tensor, multiplier):
# Determine the number of dimensions required
num_extra_dims = tensor.dim() - multiplier.dim()
# Unsqueezing the tensor to match the dimensionality
for _ in range(num_extra_dims):
multiplier = multiplier.unsqueeze(-1)
# Multiplying the broadcasted tensor with the output tensor
result = tensor * multiplier
return result
def add_bias(tensor, bias):
if bias is None:
return tensor
# add batch dim
bias = bias.unsqueeze(0)
bias = torch.cat([bias] * tensor.size(0), dim=0)
# Determine the number of dimensions required
num_extra_dims = tensor.dim() - bias.dim()
# Unsqueezing the tensor to match the dimensionality
for _ in range(num_extra_dims):
bias = bias.unsqueeze(-1)
# we may need to swap -1 for -2
if bias.size(1) != tensor.size(1):
if len(bias.size()) == 3:
bias = bias.permute(0, 2, 1)
elif len(bias.size()) == 4:
bias = bias.permute(0, 3, 1, 2)
# Multiplying the broadcasted tensor with the output tensor
try:
result = tensor + bias
except RuntimeError as e:
print(e)
print(tensor.size())
print(bias.size())
raise e
return result
class ExtractableModuleMixin:
def extract_weight(
self: Module,
extract_mode: ExtractMode = "existing",
extract_mode_param: Union[int, float] = None,
):
device = self.lora_down.weight.device
weight_to_extract = self.org_module[0].weight
if extract_mode == "existing":
extract_mode = 'fixed'
extract_mode_param = self.lora_dim
if self.org_module[0].__class__.__name__ in CONV_MODULES:
# do conv extraction
down_weight, up_weight, new_dim, diff = extract_conv(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device
)
elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
# do linear extraction
down_weight, up_weight, new_dim, diff = extract_linear(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device,
)
else:
raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")
self.lora_dim = new_dim
# inject weights into the param
self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()
# copy bias if we have one and are using them
if self.org_module[0].bias is not None and self.lora_up.bias is not None:
self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()
# set up alphas
self.alpha = (self.alpha * 0) + down_weight.shape[0]
self.scale = self.alpha / self.lora_dim
# assign them
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
# scaler is a parameter update the value with 1.0
self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)
class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
network: Network,
**kwargs
):
self.network_ref: weakref.ref = weakref.ref(network)
self.is_checkpointing = False
self._multiplier: Union[float, list, torch.Tensor] = None
def _call_forward(self: Module, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
if hasattr(self, 'lora_mid') and self.lora_mid is not None:
lx = self.lora_mid(self.lora_down(x))
else:
try:
lx = self.lora_down(x)
except RuntimeError as e:
print(f"Error in {self.__class__.__name__} lora_down")
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
lx = self.dropout(lx)
# normal dropout
elif self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.rank_dropout > 0 and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale = scale * self.scalar
return lx * scale
def lorm_forward(self: Network, x, *args, **kwargs):
network: Network = self.network_ref()
if not network.is_active:
return self.org_forward(x, *args, **kwargs)
if network.lorm_train_mode == 'local':
# we are going to predict input with both and do a loss on them
inputs = x.detach()
with torch.no_grad():
# get the local prediction
target_pred = self.org_forward(inputs, *args, **kwargs).detach()
with torch.set_grad_enabled(True):
# make a prediction with the lorm
lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))
local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
# backpropr
local_loss.backward()
network.module_losses.append(local_loss.detach())
# return the original as we dont want our trainer to affect ones down the line
return target_pred
else:
return self.lora_up(self.lora_down(x))
def forward(self: Module, x, *args, **kwargs):
skip = False
network: Network = self.network_ref()
if network.is_lorm:
# we are doing lorm
return self.lorm_forward(x, *args, **kwargs)
# skip if not active
if not network.is_active:
skip = True
# skip if is merged in
if network.is_merged_in:
skip = True
# skip if multiplier is 0
if network._multiplier == 0:
skip = True
if skip:
# network is not active, avoid doing anything
return self.org_forward(x, *args, **kwargs)
org_forwarded = self.org_forward(x, *args, **kwargs)
lora_output = self._call_forward(x)
multiplier = self.network_ref().torch_multiplier
lora_output_batch_size = lora_output.size(0)
multiplier_batch_size = multiplier.size(0)
if lora_output_batch_size != multiplier_batch_size:
num_interleaves = lora_output_batch_size // multiplier_batch_size
# todo check if this is correct, do we just concat when doing cfg?
multiplier = multiplier.repeat_interleave(num_interleaves)
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
return x
def enable_gradient_checkpointing(self: Module):
self.is_checkpointing = True
def disable_gradient_checkpointing(self: Module):
self.is_checkpointing = False
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
# make sure it is positive
merge_out_weight = abs(merge_out_weight)
# merging out is just merging in the negative of the weight
self.merge_in(merge_weight=-merge_out_weight)
@torch.no_grad()
def merge_in(self: Module, merge_weight=1.0):
# get up/down weight
up_weight = self.lora_up.weight.clone().float()
down_weight = self.lora_down.weight.clone().float()
# extract weight from org_module
org_sd = self.org_module[0].state_dict()
orig_dtype = org_sd["weight"].dtype
weight = org_sd["weight"].float()
multiplier = merge_weight
scale = self.scale
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale = scale * self.scalar
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + multiplier * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + multiplier * conved * scale
# set weight to org_module
org_sd["weight"] = weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
# LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and
# outputs the same. It is basically a LoRA but with the original module removed
# if a state dict is passed, use those weights instead of extracting
# todo load from state dict
network: Network = self.network_ref()
lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
self.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
class ToolkitNetworkMixin:
def __init__(
self: Network,
*args,
train_text_encoder: Optional[bool] = True,
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
network_config: Optional[NetworkConfig] = None,
is_lorm=False,
**kwargs
):
self.train_text_encoder = train_text_encoder
self.train_unet = train_unet
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
self.is_lorm = is_lorm
self.network_config: NetworkConfig = network_config
self.module_losses: List[torch.Tensor] = []
self.lorm_train_mode: Literal['local', None] = None
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
if self.is_sdxl:
keymap_tail = 'sdxl'
elif self.is_v2:
keymap_tail = 'sd2'
else:
keymap_tail = 'sd1'
# load keymap
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
keymap = None
# check if file exists
if os.path.exists(keymap_path):
with open(keymap_path, 'r') as f:
keymap = json.load(f)['ldm_diffusers_keymap']
return keymap
def save_weights(
self: Network,
file, dtype=torch.float16,
metadata=None,
extra_state_dict: Optional[OrderedDict] = None
):
keymap = self.get_keymap()
save_keymap = {}
if keymap is not None:
for ldm_key, diffusers_key in keymap.items():
# invert them
save_keymap[diffusers_key] = ldm_key
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
save_dict = OrderedDict()
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
save_key = save_keymap[key] if key in save_keymap else key
save_dict[save_key] = v
if extra_state_dict is not None:
# add extra items to state dict
for key in list(extra_state_dict.keys()):
v = extra_state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
save_dict[key] = v
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(save_dict, file, metadata)
else:
torch.save(save_dict, file)
def load_weights(self: Network, file):
# allows us to save and load to and from ldm weights
keymap = self.get_keymap()
keymap = {} if keymap is None else keymap
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
load_sd = OrderedDict()
for key, value in weights_sd.items():
load_key = keymap[key] if key in keymap else key
load_sd[load_key] = value
# extract extra items from state dict
current_state_dict = self.state_dict()
extra_dict = OrderedDict()
to_delete = []
for key in list(load_sd.keys()):
if key not in current_state_dict:
extra_dict[key] = load_sd[key]
to_delete.append(key)
for key in to_delete:
del load_sd[key]
info = self.load_state_dict(load_sd, False)
if len(extra_dict.keys()) == 0:
extra_dict = None
return extra_dict
@torch.no_grad()
def _update_torch_multiplier(self: Network):
# builds a tensor for fast usage in the forward pass of the network modules
# without having to set it in every single module every time it changes
multiplier = self._multiplier
# get first module
first_module = self.get_all_modules()[0]
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
elif isinstance(multiplier, list):
tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
elif isinstance(multiplier, torch.Tensor):
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
def multiplier(self) -> Union[float, List[float], List[List[float]]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float], List[List[float]]]):
# it takes time to update all the multipliers, so we only do it if the value has changed
if self._multiplier == value:
return
# if we are setting a single value but have a list, keep the list if every item is the same as value
self._multiplier = value
self._update_torch_multiplier()
# called when the context manager is entered
# ie: with network:
def __enter__(self: Network):
self.is_active = True
def __exit__(self: Network, exc_type, exc_value, tb):
self.is_active = False
def force_to(self: Network, device, dtype):
self.to(device, dtype)
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self: Network) -> List[Module]:
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
return loras
def _update_checkpointing(self: Network):
for module in self.get_all_modules():
if self.is_checkpointing:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
def enable_gradient_checkpointing(self: Network):
# not supported
self.is_checkpointing = True
self._update_checkpointing()
def disable_gradient_checkpointing(self: Network):
# not supported
self.is_checkpointing = False
self._update_checkpointing()
def merge_in(self, merge_weight=1.0):
self.is_merged_in = True
for module in self.get_all_modules():
module.merge_in(merge_weight)
def merge_out(self: Network, merge_weight=1.0):
if not self.is_merged_in:
return
self.is_merged_in = False
for module in self.get_all_modules():
module.merge_out(merge_weight)
def extract_weight(
self: Network,
extract_mode: ExtractMode = "existing",
extract_mode_param: Union[int, float] = None,
):
if extract_mode_param is None:
raise ValueError("extract_mode_param must be set")
for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
module.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
module.setup_lorm(state_dict=state_dict)
def calculate_lorem_parameter_reduction(self):
params_reduced = 0
for module in self.get_all_modules():
num_orig_module_params = count_parameters(module.org_module[0])
num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
params_reduced += (num_orig_module_params - num_lorem_params)
return params_reduced

View File

@@ -1,4 +1,5 @@
import torch
from transformers import Adafactor
def get_optimizer(
@@ -35,6 +36,8 @@ def get_optimizer(
if use_lr < 0.1:
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
use_lr = 1.0
print(f"Using lr {use_lr}")
# let net be the neural network you want to train
# you can choose weight decay value based on your problem, 0 by default
optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
@@ -43,6 +46,8 @@ def get_optimizer(
if lower_type == "adam8bit":
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
elif lower_type == "adamw8bit":
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
elif lower_type == "lion8bit":
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
else:
@@ -52,10 +57,15 @@ def get_optimizer(
elif lower_type == 'adamw':
optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
elif lower_type == 'lion':
from lion_pytorch import Lion
return Lion(params, lr=learning_rate, **optimizer_params)
try:
from lion_pytorch import Lion
return Lion(params, lr=learning_rate, **optimizer_params)
except ImportError:
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
elif lower_type == 'adagrad':
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
elif lower_type == 'adafactor':
optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
else:
raise ValueError(f'Unknown optimizer type {optimizer_type}')
return optimizer

View File

@@ -0,0 +1,91 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@@ -5,6 +5,8 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
# check if ENV variable is set
if 'MODELS_PATH' in os.environ:

View File

@@ -1,14 +1,297 @@
import importlib
import inspect
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
def __init__(
self,
vae: 'AutoencoderKL',
text_encoder: 'CLIPTextModel',
text_encoder_2: 'CLIPTextModelWithProjection',
tokenizer: 'CLIPTokenizer',
tokenizer_2: 'CLIPTokenizer',
unet: 'UNet2DConditionModel',
scheduler: 'KarrasDiffusionSchedulers',
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
super().__init__(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
)
self.sampler = None
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
self.k_diffusion_model = CompVisVDenoiser(model)
else:
self.k_diffusion_model = CompVisDenoiser(model)
def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
use_karras_sigmas: bool = False,
):
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# 5. Prepare sigmas
if use_karras_sigmas:
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
sigmas = sigmas.to(device)
else:
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(prompt_embeds.dtype)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents = latents * sigmas[0]
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
# 7. Define model function
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
t = torch.cat([t] * 2)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# noise_pred = self.unet(
# latent_model_input,
# t,
# encoder_hidden_states=prompt_embeds,
# cross_attention_kwargs=cross_attention_kwargs,
# added_cond_kwargs=added_cond_kwargs,
# return_dict=False,
# )[0]
noise_pred = self.k_diffusion_model(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred
# 8. Run k-diffusion solver
sampler_kwargs = {}
# should work without it
noise_sampler_seed = None
if "noise_sampler" in inspect.signature(self.sampler).parameters:
min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
sampler_kwargs["noise_sampler"] = noise_sampler
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
def predict_noise(
self,
@@ -532,3 +815,389 @@ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
return noise_pred
class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
denoising_start: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a target image resolution. It should be as same
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# 8.2 Determine denoising_start
denoising_start_index = 0
if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
discrete_timestep_start = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
for i, t in enumerate(timesteps, start=denoising_start_index):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
else:
image = latents
if not output_type == "latent":
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)

25
toolkit/progress_bar.py Normal file
View File

@@ -0,0 +1,25 @@
from tqdm import tqdm
import time
class ToolkitProgressBar(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.paused = False
self.last_time = self._time()
def pause(self):
if not self.paused:
self.paused = True
self.last_time = self._time()
def unpause(self):
if self.paused:
self.paused = False
cur_t = self._time()
self.start_t += cur_t - self.last_time
self.last_print_t = cur_t
def update(self, *args, **kwargs):
if not self.paused:
super().update(*args, **kwargs)

View File

@@ -1,12 +1,11 @@
import os
from typing import Optional, TYPE_CHECKING, List
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import random
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import itertools
@@ -19,6 +18,39 @@ class ACTION_TYPES_SLIDER:
ENHANCE_NEGATIVE = 1
class PromptEmbeds:
text_embeds: torch.Tensor
pooled_embeds: Union[torch.Tensor, None]
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
if isinstance(args, list) or isinstance(args, tuple):
# xl
self.text_embeds = args[0]
self.pooled_embeds = args[1]
else:
# sdv1.x, sdv2.x
self.text_embeds = args
self.pooled_embeds = None
def to(self, *args, **kwargs):
self.text_embeds = self.text_embeds.to(*args, **kwargs)
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
return self
def detach(self):
self.text_embeds = self.text_embeds.detach()
if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.detach()
return self
def clone(self):
if self.pooled_embeds is not None:
return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
else:
return PromptEmbeds(self.text_embeds.clone())
class EncodedPromptPair:
def __init__(
self,
@@ -73,6 +105,18 @@ class EncodedPromptPair:
self.both_targets = self.both_targets.to(*args, **kwargs)
return self
def detach(self):
self.target_class = self.target_class.detach()
self.target_class_with_neutral = self.target_class_with_neutral.detach()
self.positive_target = self.positive_target.detach()
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
self.negative_target = self.negative_target.detach()
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
self.neutral = self.neutral.detach()
self.empty_prompt = self.empty_prompt.detach()
self.both_targets = self.both_targets.detach()
return self
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
@@ -235,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
return anchors
def get_permutations(s):
def get_permutations(s, max_permutations=8):
# Split the string by comma
phrases = [phrase.strip() for phrase in s.split(',')]
# remove empty strings
phrases = [phrase for phrase in phrases if len(phrase) > 0]
# shuffle the list
random.shuffle(phrases)
# Get all permutations
permutations = list(itertools.permutations(phrases))
permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])
# Convert the tuples back to comma separated strings
return [', '.join(permutation) for permutation in permutations]
@@ -251,8 +297,8 @@ def get_permutations(s):
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
permutations = []
for pos, neg in itertools.product(pos_permutations, neg_permutations):
@@ -465,3 +511,39 @@ def build_latent_image_batch_for_prompt_pair(
latent_list.append(neg_latent)
return torch.cat(latent_list, dim=0)
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
if trigger is None:
# process as empty string to remove any [trigger] tokens
trigger = ''
output_prompt = prompt
default_replacements = ["[name]", "[trigger]"]
replace_with = trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
if trigger.strip() != "":
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
# if num_instances > 1:
# print(
# f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt

116
toolkit/sampler.py Normal file
View File

@@ -0,0 +1,116 @@
import copy
from diffusers import (
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
LCMScheduler
)
from k_diffusion.external import CompVisDenoiser
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
sdxl_sampler_config = {
"_class_name": "EulerDiscreteScheduler",
"_diffusers_version": "0.19.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
"trained_betas": None,
"use_karras_sigmas": False
}
def get_sampler(
sampler: str,
):
sched_init_args = {}
if sampler.startswith("k_"):
sched_init_args["use_karras_sigmas"] = True
if sampler == "ddim":
scheduler_cls = DDIMScheduler
elif sampler == "ddpm": # ddpm is not supported ?
scheduler_cls = DDPMScheduler
elif sampler == "pndm":
scheduler_cls = PNDMScheduler
elif sampler == "lms" or sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif sampler == "euler" or sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif sampler == "euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = sampler.replace("k_", "")
elif sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif sampler == "dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif sampler == "dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
elif sampler == "lcm":
scheduler_cls = LCMScheduler
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
config = copy.deepcopy(sdxl_sampler_config)
config.update(sched_init_args)
scheduler = scheduler_cls.from_config(config)
return scheduler
# testing
if __name__ == "__main__":
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionKDiffusionPipeline
import torch
import os
inference_steps = 25
pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")
k_diffusion_model = CompVisDenoiser(model)
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
pipe.set_scheduler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image.save("./astronaut_heun_k_diffusion.png")

View File

@@ -0,0 +1,553 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class LCMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
denoised: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.FloatTensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.FloatTensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
class CustomLCMScheduler(SchedulerMixin, ConfigMixin):
"""
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
non-Markovian guidance.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
original_inference_steps (`int`, *optional*, defaults to 50):
The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, defaults to `True`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
Diffusion.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
timestep_scaling (`float`, defaults to 10.0):
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
error at the default of `10.0` is already pretty small).
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
original_inference_steps: int = 50,
clip_sample: bool = False,
clip_sample_range: float = 1.0,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
timestep_scaling: float = 10.0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
self.original_inference_steps = 50
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.train_timesteps = 1000
self._step_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
@property
def step_index(self):
return self._step_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
strength: int = 1.0,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
original_inference_steps (`int`, *optional*):
The original number of inference steps, which will be used to generate a linearly-spaced timestep
schedule (which is different from the standard `diffusers` implementation). We will then take
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
"""
original_inference_steps = self.original_inference_steps
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
original_steps = (
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
)
if original_steps > self.config.num_train_timesteps:
raise ValueError(
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
if num_inference_steps > original_steps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
f" {original_steps} because the final timestep schedule will be a subset of the"
f" `original_inference_steps`-sized initial timestep schedule."
)
# LCM Timesteps Setting
# The skipping step parameter k from the paper.
k = self.config.num_train_timesteps // original_steps
# LCM Training/Distillation Steps Schedule
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
if skipping_step < 1:
raise ValueError(
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
)
# LCM Inference Steps Schedule
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
inference_indices = np.floor(inference_indices).astype(np.int64)
timesteps = lcm_origin_timesteps[inference_indices]
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
self._step_index = None
def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
scaled_timestep = timestep * self.config.timestep_scaling
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[LCMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# 1. get previous step value
prev_step_index = self.step_index + 1
if prev_step_index < len(self.timesteps):
prev_timestep = self.timesteps[prev_step_index]
else:
prev_timestep = timestep
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 3. Get scalings for boundary conditions
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
# 4. Compute the predicted original sample x_0 based on the model parameterization
if self.config.prediction_type == "epsilon": # noise-prediction
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif self.config.prediction_type == "sample": # x-prediction
predicted_original_sample = model_output
elif self.config.prediction_type == "v_prediction": # v-prediction
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for `LCMScheduler`."
)
# 5. Clip or threshold "predicted x_0"
if self.config.thresholding:
predicted_original_sample = self._threshold_sample(predicted_original_sample)
elif self.config.clip_sample:
predicted_original_sample = predicted_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
# 6. Denoise model output using boundary conditions
denoised = c_out * predicted_original_sample + c_skip * sample
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
# Noise is not used on the final timestep of the timestep schedule.
# This also means that noise is not used for one-step sampling.
if self.step_index != self.num_inference_steps - 1:
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
)
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample, denoised)
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -32,6 +32,10 @@ def convert_state_dict_to_ldm_with_mapping(
with open(mapping_path, 'r') as f:
mapping = json.load(f, object_pairs_hook=OrderedDict)
# keep track of keys not matched
ldm_matched_keys = []
diffusers_matched_keys = []
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
@@ -52,11 +56,15 @@ def convert_state_dict_to_ldm_with_mapping(
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
cat_list.append(diffusers_state_dict[diffusers_key].detach())
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat'])
ldm_matched_keys.append(ldm_key)
if 'slice' in ldm_diffusers_operator_map[ldm_key]:
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
dtype=dtype)
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice'])
ldm_matched_keys.append(ldm_key)
# process the rest of the keys
for ldm_key in ldm_diffusers_keymap:
@@ -67,13 +75,29 @@ def convert_state_dict_to_ldm_with_mapping(
if ldm_key in ldm_diffusers_shape_map:
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
converted_state_dict[ldm_key] = tensor
diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key])
ldm_matched_keys.append(ldm_key)
# see if any are missing from know mapping
mapped_diffusers_keys = list(ldm_diffusers_keymap.values())
mapped_ldm_keys = list(ldm_diffusers_keymap.keys())
missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys]
missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys]
if len(missing_diffusers_keys) > 0:
print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys")
print(missing_diffusers_keys)
if len(missing_ldm_keys) > 0:
print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys")
print(missing_ldm_keys)
return converted_state_dict
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl'] = '2',
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
@@ -87,6 +111,14 @@ def get_ldm_state_dict_from_diffusers(
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
elif sd_version == 'ssd':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
elif sd_version == 'sdxl_refiner':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
else:
raise ValueError(f"Invalid sd_version {sd_version}")
@@ -105,7 +137,7 @@ def save_ldm_model_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = get_ldm_state_dict_from_diffusers(
sd.state_dict(),
@@ -117,3 +149,95 @@ def save_ldm_model_from_diffusers(
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def save_lora_from_diffusers(
lora_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = OrderedDict()
# only handle sxdxl for now
if sd_version != 'sdxl' and sd_version != 'ssd':
raise ValueError(f"Invalid sd_version {sd_version}")
for key, value in lora_state_dict.items():
# todo verify if this works with ssd
# test encoders share keys for some reason
if key.begins_with('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
else:
converted_key = key
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def save_t2i_from_diffusers(
t2i_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for key, value in t2i_state_dict.items():
converted_state_dict[key] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def load_t2i_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
raw_state_dict = load_file(path_to_file, device)
converted_state_dict = OrderedDict()
for key, value in raw_state_dict.items():
# todo see if we need to convert dict
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
return converted_state_dict
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
def save_ip_adapter_from_diffusers(
combined_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)

View File

@@ -1,33 +1,57 @@
import torch
from typing import Optional
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup
def get_lr_scheduler(
name: Optional[str],
optimizer: torch.optim.Optimizer,
max_iterations: Optional[int],
lr_min: Optional[float],
**kwargs,
):
if name == "cosine":
if 'total_iters' in kwargs:
kwargs['T_max'] = kwargs.pop('total_iters')
return torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
optimizer, **kwargs
)
elif name == "cosine_with_restarts":
if 'total_iters' in kwargs:
kwargs['T_0'] = kwargs.pop('total_iters')
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
optimizer, **kwargs
)
elif name == "step":
return torch.optim.lr_scheduler.StepLR(
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
optimizer, **kwargs
)
elif name == "constant":
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
if 'facor' not in kwargs:
kwargs['factor'] = 1.0
return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
elif name == "linear":
return torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
optimizer, **kwargs
)
elif name == 'constant_with_warmup':
# see if num_warmup_steps is in kwargs
if 'num_warmup_steps' not in kwargs:
print(f"WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
kwargs['num_warmup_steps'] = 1000
del kwargs['total_iters']
return get_constant_schedule_with_warmup(optimizer, **kwargs)
else:
# try to use a diffusers scheduler
print(f"Trying to use diffusers scheduler {name}")
try:
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
return schedule_func(optimizer, **kwargs)
except Exception as e:
print(e)
pass
raise ValueError(
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
)

View File

@@ -0,0 +1,88 @@
from typing import Union
import torch
import copy
empty_preset = {
'vae': {
'training': False,
'device': 'cpu',
},
'unet': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
'text_encoder': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
'adapter': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
'refiner_unet': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
}
def get_train_sd_device_state_preset(
device: Union[str, torch.device],
train_unet: bool = False,
train_text_encoder: bool = False,
cached_latents: bool = False,
train_lora: bool = False,
train_adapter: bool = False,
train_embedding: bool = False,
train_refiner: bool = False,
):
preset = copy.deepcopy(empty_preset)
if not cached_latents:
preset['vae']['device'] = device
if train_unet:
preset['unet']['training'] = True
preset['unet']['requires_grad'] = True
preset['unet']['device'] = device
else:
preset['unet']['device'] = device
if train_text_encoder:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['device'] = device
else:
preset['text_encoder']['device'] = device
if train_embedding:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['training'] = True
preset['unet']['training'] = True
if train_refiner:
preset['refiner_unet']['training'] = True
preset['refiner_unet']['requires_grad'] = True
preset['refiner_unet']['device'] = device
# if not training unet, move that to cpu
if not train_unet:
preset['unet']['device'] = 'cpu'
if train_lora:
# preset['text_encoder']['requires_grad'] = False
preset['unet']['requires_grad'] = False
if train_refiner:
preset['refiner_unet']['requires_grad'] = False
if train_adapter:
preset['adapter']['requires_grad'] = True
preset['adapter']['training'] = True
preset['adapter']['device'] = device
preset['unet']['training'] = True
return preset

File diff suppressed because it is too large Load Diff

View File

@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
# Define the separate loss function
def separated_loss(y_pred, y_true):
y_pred = y_pred.float()
y_true = y_true.float()
diff = torch.abs(y_pred - y_true)
l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
return 2. * l2 / content_size
# Calculate itemized loss
pred_itemized_loss = separated_loss(pred_layer, target_layer)
# check if is nan
if torch.isnan(pred_itemized_loss).any():
print('pred_itemized_loss is nan')
# Calculate the mean of itemized loss
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
def convert_to_gram_matrix(inputs):
inputs = inputs.float()
shape = inputs.size()
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
size = height * width * filters
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
target_grams = convert_to_gram_matrix(style_target)
pred_grams = convert_to_gram_matrix(preds)
itemized_loss = separated_loss(pred_grams, target_grams)
# check if is nan
if torch.isnan(itemized_loss).any():
print('itemized_loss is nan')
# reshape itemized loss to be (batch, 1, 1, 1)
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
self.loss = loss.to(input_dtype)
self.loss = loss.to(input_dtype).float()
return stacked_input.to(input_dtype)
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv4_2']
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype

65
toolkit/timer.py Normal file
View File

@@ -0,0 +1,65 @@
import time
from collections import OrderedDict, deque
class Timer:
def __init__(self, name='Timer', max_buffer=10):
self.name = name
self.max_buffer = max_buffer
self.timers = OrderedDict()
self.active_timers = {}
self.current_timer = None # Used for the context manager functionality
def start(self, timer_name):
if timer_name not in self.timers:
self.timers[timer_name] = deque(maxlen=self.max_buffer)
self.active_timers[timer_name] = time.time()
def cancel(self, timer_name):
"""Cancel an active timer."""
if timer_name in self.active_timers:
del self.active_timers[timer_name]
def stop(self, timer_name):
if timer_name not in self.active_timers:
raise ValueError(f"Timer '{timer_name}' was not started!")
elapsed_time = time.time() - self.active_timers[timer_name]
self.timers[timer_name].append(elapsed_time)
# Clean up active timers
del self.active_timers[timer_name]
# Check if this timer's buffer exceeds max_buffer and remove the oldest if it does
if len(self.timers[timer_name]) > self.max_buffer:
self.timers[timer_name].popleft()
def print(self):
print(f"\nTimer '{self.name}':")
# sort by longest at top
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
avg_time = sum(timings) / len(timings)
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
print('')
def reset(self):
self.timers.clear()
self.active_timers.clear()
def __call__(self, timer_name):
"""Enable the use of the Timer class as a context manager."""
self.current_timer = timer_name
self.start(timer_name)
return self
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# No exceptions, stop the timer normally
self.stop(self.current_timer)
else:
# There was an exception, cancel the timer
self.cancel(self.current_timer)

View File

@@ -5,6 +5,9 @@ import os
import time
from typing import TYPE_CHECKING, Union
import sys
from torch.cuda.amp import GradScaler
from toolkit.paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
@@ -444,29 +447,78 @@ if TYPE_CHECKING:
def text_tokenize(
tokenizer: 'CLIPTokenizer', # 普通ならひとつ、XLならふたつ
tokenizer: 'CLIPTokenizer',
prompts: list[str],
truncate: bool = True,
max_length: int = None,
max_length_multiplier: int = 4,
):
return tokenizer(
# allow fo up to 4x the max length for long prompts
if max_length is None:
if truncate:
max_length = tokenizer.model_max_length
else:
# allow up to 4x the max length for long prompts
max_length = tokenizer.model_max_length * max_length_multiplier
input_ids = tokenizer(
prompts,
padding="max_length",
max_length=tokenizer.model_max_length,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors="pt",
).input_ids
if truncate or max_length == tokenizer.model_max_length:
return input_ids
else:
# remove additional padding
num_chunks = input_ids.shape[1] // tokenizer.model_max_length
chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
# New list to store non-redundant chunks
non_redundant_chunks = []
for chunk in chunks:
if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element
non_redundant_chunks.append(chunk)
input_ids = torch.cat(non_redundant_chunks, dim=1)
return input_ids
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
tokens: torch.FloatTensor,
num_images_per_prompt: int = 1,
max_length: int = 77, # not sure what default to put here, always pass one?
truncate: bool = True,
):
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
if truncate:
# normal short prompt 77 tokens max
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
else:
# handle long prompts
prompt_embeds_list = []
tokens = tokens.to(text_encoder.device)
pooled_prompt_embeds = None
for i in range(0, tokens.shape[-1], max_length):
# todo run it through the in a single batch
section_tokens = tokens[:, i: i + max_length]
embeds = text_encoder(section_tokens, output_hidden_states=True)
pooled_prompt_embed = embeds[0]
if pooled_prompt_embeds is None:
# we only want the first ( I think??)
pooled_prompt_embeds = pooled_prompt_embed
prompt_embed = embeds.hidden_states[-2] # always penultimate layer
prompt_embeds_list.append(prompt_embed)
prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -479,26 +531,43 @@ def encode_prompts_xl(
tokenizers: list['CLIPTokenizer'],
text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
prompts: list[str],
prompts2: Union[list[str], None],
num_images_per_prompt: int = 1,
use_text_encoder_1: bool = True, # sdxl
use_text_encoder_2: bool = True # sdxl
use_text_encoder_2: bool = True, # sdxl
truncate: bool = True,
max_length=None,
dropout_prob=0.0,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# text_encoder and text_encoder_2's penuultimate layer's output
text_embeds_list = []
pooled_text_embeds = None # always text_encoder_2's pool
if prompts2 is None:
prompts2 = prompts
for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
# todo, we are using a blank string to ignore that encoder for now.
# find a better way to do this (zeroing?, removing it from the unet?)
prompt_list_to_use = prompts
prompt_list_to_use = prompts if idx == 0 else prompts2
if idx == 0 and not use_text_encoder_1:
prompt_list_to_use = ["" for _ in prompts]
if idx == 1 and not use_text_encoder_2:
prompt_list_to_use = ["" for _ in prompts]
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
if dropout_prob > 0.0:
# randomly drop out prompts
prompt_list_to_use = [
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
]
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
# set the max length for the next one
if idx == 0:
max_length = text_tokens_input_ids.shape[-1]
text_embeds, pooled_text_embeds = text_encode_xl(
text_encoder, text_tokens_input_ids, num_images_per_prompt
text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
truncate=truncate
)
text_embeds_list.append(text_embeds)
@@ -511,17 +580,49 @@ def encode_prompts_xl(
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
def text_encode(text_encoder: 'CLIPTextModel', tokens):
return text_encoder(tokens.to(text_encoder.device))[0]
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
if max_length is None and not truncate:
raise ValueError("max_length must be set if truncate is True")
try:
tokens = tokens.to(text_encoder.device)
except Exception as e:
print(e)
print("tokens.device", tokens.device)
print("text_encoder.device", text_encoder.device)
raise e
if truncate:
return text_encoder(tokens)[0]
else:
# handle long prompts
prompt_embeds_list = []
for i in range(0, tokens.shape[-1], max_length):
prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
prompt_embeds_list.append(prompt_embeds)
return torch.cat(prompt_embeds_list, dim=1)
def encode_prompts(
tokenizer: 'CLIPTokenizer',
text_encoder: 'CLIPTokenizer',
text_encoder: 'CLIPTextModel',
prompts: list[str],
truncate: bool = True,
max_length=None,
dropout_prob=0.0,
):
text_tokens = text_tokenize(tokenizer, prompts)
text_embeddings = text_encode(text_encoder, text_tokens)
if max_length is None:
max_length = tokenizer.model_max_length
if dropout_prob > 0.0:
# randomly drop out prompts
prompts = [
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
]
text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
return text_embeddings
@@ -602,18 +703,86 @@ def get_all_snr(noise_scheduler, device):
all_snr.requires_grad = False
return all_snr.to(device)
class LearnableSNRGamma:
"""
This is a trainer for learnable snr gamma
It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
"""
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
self.device = device
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
self.buffer = []
self.max_buffer_size = 20
def forward(self, loss, timesteps):
# do a our train loop for lsnr here and return our values detached
loss = loss.detach()
with torch.no_grad():
loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
for loss_chunk in loss_chunks:
self.buffer.append(loss_chunk.mean().detach())
if len(self.buffer) > self.max_buffer_size:
self.buffer.pop(0)
all_snr = get_all_snr(self.noise_scheduler, loss.device)
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
base_snrs = snr.clone().detach()
snr.requires_grad = True
snr = (snr + self.offset_1) * self.scale + self.offset_2
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
snr_adjusted_loss = loss * snr_weight
with torch.no_grad():
target = torch.mean(torch.stack(self.buffer)).detach()
# local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
squared_differences = (snr_adjusted_loss - target) ** 2
local_loss = torch.mean(squared_differences)
local_loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()
def apply_learnable_snr_gos(
loss,
timesteps,
learnable_snr_trainer: LearnableSNRGamma
):
snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
snr = (snr + offset_1) * scale + offset_2
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss
def apply_snr_weight(
loss,
timesteps,
noise_scheduler: Union['DDPMScheduler'],
gamma
gamma,
fixed=False,
):
# will get it form noise scheduler if exist or will calculate it if not
# will get it from noise scheduler if exist or will calculate it if not
all_snr = get_all_snr(noise_scheduler, loss.device)
snr = torch.stack([all_snr[t] for t in timesteps])
step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
snr = torch.stack([all_snr[t] for t in step_indices])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
loss = loss * snr_weight
return loss
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
else:
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
snr_adjusted_loss = loss * snr_weight
return snr_adjusted_loss