mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Merge remote-tracking branch 'origin/development'
# Conflicts: # toolkit/stable_diffusion_model.py
This commit is contained in:
6
.gitmodules
vendored
6
.gitmodules
vendored
@@ -4,3 +4,9 @@
|
||||
[submodule "repositories/leco"]
|
||||
path = repositories/leco
|
||||
url = https://github.com/p1atdev/LECO
|
||||
[submodule "repositories/batch_annotator"]
|
||||
path = repositories/batch_annotator
|
||||
url = https://github.com/ostris/batch-annotator
|
||||
[submodule "repositories/ipadapter"]
|
||||
path = repositories/ipadapter
|
||||
url = https://github.com/tencent-ailab/IP-Adapter.git
|
||||
|
||||
102
extensions_built_in/advanced_generator/PureLoraGenerator.py
Normal file
102
extensions_built_in/advanced_generator/PureLoraGenerator.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
|
||||
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseExtensionProcess
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class PureLoraGenerator(BaseExtensionProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.output_folder = self.get_conf('output_folder', required=True)
|
||||
self.device = self.get_conf('device', 'cuda')
|
||||
self.device_torch = torch.device(self.device)
|
||||
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
||||
self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
|
||||
self.dtype = self.get_conf('dtype', 'float16')
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
lorm_config = self.get_conf('lorm', None)
|
||||
self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
|
||||
|
||||
self.device_state_preset = get_train_sd_device_state_preset(
|
||||
device=torch.device(self.device),
|
||||
)
|
||||
|
||||
self.progress_bar = None
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
print("Loading model...")
|
||||
with torch.no_grad():
|
||||
self.sd.load_model()
|
||||
self.sd.unet.eval()
|
||||
self.sd.unet.to(self.device_torch)
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
te.to(self.device_torch)
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
self.sd.to(self.device_torch)
|
||||
|
||||
print(f"Converting to LoRM UNet")
|
||||
# replace the unet with LoRMUnet
|
||||
convert_diffusers_unet_to_lorm(
|
||||
self.sd.unet,
|
||||
config=self.lorm_config,
|
||||
)
|
||||
|
||||
sample_folder = os.path.join(self.output_folder)
|
||||
gen_img_config_list = []
|
||||
|
||||
sample_config = self.generate_config
|
||||
start_seed = sample_config.seed
|
||||
current_seed = start_seed
|
||||
for i in range(len(sample_config.prompts)):
|
||||
if sample_config.walk_seed:
|
||||
current_seed = start_seed + i
|
||||
|
||||
filename = f"[time]_[count].{self.generate_config.ext}"
|
||||
output_path = os.path.join(sample_folder, filename)
|
||||
prompt = sample_config.prompts[i]
|
||||
extra_args = {}
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
width=sample_config.width,
|
||||
height=sample_config.height,
|
||||
negative_prompt=sample_config.neg,
|
||||
seed=current_seed,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
guidance_rescale=sample_config.guidance_rescale,
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
network_multiplier=sample_config.network_multiplier,
|
||||
output_path=output_path,
|
||||
output_ext=sample_config.ext,
|
||||
adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
|
||||
**extra_args
|
||||
))
|
||||
|
||||
# send to be generated
|
||||
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
||||
print("Done generating images")
|
||||
# cleanup
|
||||
del self.sd
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
193
extensions_built_in/advanced_generator/ReferenceGenerator.py
Normal file
193
extensions_built_in/advanced_generator/ReferenceGenerator.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from torch.utils.data import DataLoader
|
||||
from diffusers import StableDiffusionXLAdapterPipeline
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseExtensionProcess
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from controlnet_aux.midas import MidasDetector
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class GenerateConfig:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.prompts: List[str]
|
||||
self.sampler = kwargs.get('sampler', 'ddpm')
|
||||
self.neg = kwargs.get('neg', '')
|
||||
self.seed = kwargs.get('seed', -1)
|
||||
self.walk_seed = kwargs.get('walk_seed', False)
|
||||
self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
|
||||
self.guidance_scale = kwargs.get('guidance_scale', 7)
|
||||
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||
self.prompt_2 = kwargs.get('prompt_2', None)
|
||||
self.neg_2 = kwargs.get('neg_2', None)
|
||||
self.prompts = kwargs.get('prompts', None)
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext = kwargs.get('ext', 'png')
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
if kwargs.get('shuffle', False):
|
||||
# shuffle the prompts
|
||||
random.shuffle(self.prompts)
|
||||
|
||||
|
||||
class ReferenceGenerator(BaseExtensionProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.output_folder = self.get_conf('output_folder', required=True)
|
||||
self.device = self.get_conf('device', 'cuda')
|
||||
self.model_config = ModelConfig(**self.get_conf('model', required=True))
|
||||
self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
|
||||
self.is_latents_cached = True
|
||||
raw_datasets = self.get_conf('datasets', None)
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
|
||||
self.datasets = None
|
||||
self.datasets_reg = None
|
||||
self.dtype = self.get_conf('dtype', 'float16')
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.params = []
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
for raw_dataset in raw_datasets:
|
||||
dataset = DatasetConfig(**raw_dataset)
|
||||
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
||||
if not is_caching:
|
||||
self.is_latents_cached = False
|
||||
if dataset.is_reg:
|
||||
if self.datasets_reg is None:
|
||||
self.datasets_reg = []
|
||||
self.datasets_reg.append(dataset)
|
||||
else:
|
||||
if self.datasets is None:
|
||||
self.datasets = []
|
||||
self.datasets.append(dataset)
|
||||
|
||||
self.progress_bar = None
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
print(f"Using device {self.device}")
|
||||
self.data_loader: DataLoader = None
|
||||
self.adapter: T2IAdapter = None
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
print("Loading model...")
|
||||
self.sd.load_model()
|
||||
device = torch.device(self.device)
|
||||
|
||||
if self.generate_config.t2i_adapter_path is not None:
|
||||
self.adapter = T2IAdapter.from_pretrained(
|
||||
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
|
||||
).to(device)
|
||||
|
||||
midas_depth = MidasDetector.from_pretrained(
|
||||
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
|
||||
).to(device)
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline(
|
||||
vae=self.sd.vae,
|
||||
unet=self.sd.unet,
|
||||
text_encoder=self.sd.text_encoder[0],
|
||||
text_encoder_2=self.sd.text_encoder[1],
|
||||
tokenizer=self.sd.tokenizer[0],
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=get_sampler(self.generate_config.sampler),
|
||||
adapter=self.adapter,
|
||||
).to(device)
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
||||
|
||||
num_batches = len(self.data_loader)
|
||||
pbar = tqdm(total=num_batches, desc="Generating images")
|
||||
seed = self.generate_config.seed
|
||||
# load images from datasets, use tqdm
|
||||
for i, batch in enumerate(self.data_loader):
|
||||
batch: DataLoaderBatchDTO = batch
|
||||
|
||||
file_item: FileItemDTO = batch.file_items[0]
|
||||
img_path = file_item.path
|
||||
img_filename = os.path.basename(img_path)
|
||||
img_filename_no_ext = os.path.splitext(img_filename)[0]
|
||||
output_path = os.path.join(self.output_folder, img_filename)
|
||||
output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
|
||||
output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
|
||||
|
||||
caption = batch.get_caption_list()[0]
|
||||
|
||||
img: torch.Tensor = batch.tensor.clone()
|
||||
# image comes in -1 to 1. convert to a PIL RGB image
|
||||
img = (img + 1) / 2
|
||||
img = img.clamp(0, 1)
|
||||
img = img[0].permute(1, 2, 0).cpu().numpy()
|
||||
img = (img * 255).astype(np.uint8)
|
||||
image = Image.fromarray(img)
|
||||
|
||||
width, height = image.size
|
||||
min_res = min(width, height)
|
||||
|
||||
if self.generate_config.walk_seed:
|
||||
seed = seed + 1
|
||||
|
||||
if self.generate_config.seed == -1:
|
||||
# random
|
||||
seed = random.randint(0, 1000000)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# generate depth map
|
||||
image = midas_depth(
|
||||
image,
|
||||
detect_resolution=min_res, # do 512 ?
|
||||
image_resolution=min_res
|
||||
)
|
||||
|
||||
# image.save(output_depth_path)
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=caption,
|
||||
negative_prompt=self.generate_config.neg,
|
||||
image=image,
|
||||
num_inference_steps=self.generate_config.sample_steps,
|
||||
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
).images[0]
|
||||
gen_images.save(output_path)
|
||||
|
||||
# save caption
|
||||
with open(output_caption_path, 'w') as f:
|
||||
f.write(caption)
|
||||
|
||||
pbar.update(1)
|
||||
batch.cleanup()
|
||||
|
||||
pbar.close()
|
||||
print("Done generating images")
|
||||
# cleanup
|
||||
del self.sd
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
42
extensions_built_in/advanced_generator/__init__.py
Normal file
42
extensions_built_in/advanced_generator/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
||||
from toolkit.extension import Extension
|
||||
|
||||
|
||||
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
||||
class AdvancedReferenceGeneratorExtension(Extension):
|
||||
# uid must be unique, it is how the extension is identified
|
||||
uid = "reference_generator"
|
||||
|
||||
# name is the name of the extension for printing
|
||||
name = "Reference Generator"
|
||||
|
||||
# This is where your process class is loaded
|
||||
# keep your imports in here so they don't slow down the rest of the program
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .ReferenceGenerator import ReferenceGenerator
|
||||
return ReferenceGenerator
|
||||
|
||||
|
||||
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
||||
class PureLoraGenerator(Extension):
|
||||
# uid must be unique, it is how the extension is identified
|
||||
uid = "pure_lora_generator"
|
||||
|
||||
# name is the name of the extension for printing
|
||||
name = "Pure LoRA Generator"
|
||||
|
||||
# This is where your process class is loaded
|
||||
# keep your imports in here so they don't slow down the rest of the program
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .PureLoraGenerator import PureLoraGenerator
|
||||
return PureLoraGenerator
|
||||
|
||||
|
||||
AI_TOOLKIT_EXTENSIONS = [
|
||||
# you can put a list of extensions here
|
||||
AdvancedReferenceGeneratorExtension, PureLoraGenerator
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
---
|
||||
job: extension
|
||||
config:
|
||||
name: test_v1
|
||||
process:
|
||||
- type: 'textual_inversion_trainer'
|
||||
training_folder: "out/TI"
|
||||
device: cuda:0
|
||||
# for tensorboard logging
|
||||
log_dir: "out/.tensorboard"
|
||||
embedding:
|
||||
trigger: "your_trigger_here"
|
||||
tokens: 12
|
||||
init_words: "man with short brown hair"
|
||||
save_format: "safetensors" # 'safetensors' or 'pt'
|
||||
save:
|
||||
dtype: float16 # precision to save
|
||||
save_every: 100 # save every this many steps
|
||||
max_step_saves_to_keep: 5 # only affects step counts
|
||||
datasets:
|
||||
- folder_path: "/path/to/dataset"
|
||||
caption_ext: "txt"
|
||||
default_caption: "[trigger]"
|
||||
buckets: true
|
||||
resolution: 512
|
||||
train:
|
||||
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
||||
steps: 3000
|
||||
weight_jitter: 0.0
|
||||
lr: 5e-5
|
||||
train_unet: false
|
||||
gradient_checkpointing: true
|
||||
train_text_encoder: false
|
||||
optimizer: "adamw"
|
||||
# optimizer: "prodigy"
|
||||
optimizer_params:
|
||||
weight_decay: 1e-2
|
||||
lr_scheduler: "constant"
|
||||
max_denoising_steps: 1000
|
||||
batch_size: 4
|
||||
dtype: bf16
|
||||
xformers: true
|
||||
min_snr_gamma: 5.0
|
||||
# skip_first_sample: true
|
||||
noise_offset: 0.0 # not needed for this
|
||||
model:
|
||||
# objective reality v2
|
||||
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
|
||||
is_v2: false # for v2 models
|
||||
is_xl: false # for SDXL models
|
||||
is_v_pred: false # for v-prediction models (most v2 models)
|
||||
sample:
|
||||
sampler: "ddpm" # must match train.noise_scheduler
|
||||
sample_every: 100 # sample every this many steps
|
||||
width: 512
|
||||
height: 512
|
||||
prompts:
|
||||
- "photo of [trigger] laughing"
|
||||
- "photo of [trigger] smiling"
|
||||
- "[trigger] close up"
|
||||
- "dark scene [trigger] frozen"
|
||||
- "[trigger] nighttime"
|
||||
- "a painting of [trigger]"
|
||||
- "a drawing of [trigger]"
|
||||
- "a cartoon of [trigger]"
|
||||
- "[trigger] pixar style"
|
||||
- "[trigger] costume"
|
||||
neg: ""
|
||||
seed: 42
|
||||
walk_seed: false
|
||||
guidance_scale: 7
|
||||
sample_steps: 20
|
||||
network_multiplier: 1.0
|
||||
|
||||
logging:
|
||||
log_every: 10 # log every this many steps
|
||||
use_wandb: false # not supported yet
|
||||
verbose: false
|
||||
|
||||
# You can put any information you want here, and it will be saved in the model.
|
||||
# The below is an example, but you can put your grocery list in it if you want.
|
||||
# It is saved in the model so be aware of that. The software will include this
|
||||
# plus some other information for you automatically
|
||||
meta:
|
||||
# [name] gets replaced with the name above
|
||||
name: "[name]"
|
||||
# version: '1.0'
|
||||
# creator:
|
||||
# name: Your Name
|
||||
# email: your@gmail.com
|
||||
# website: https://your.website
|
||||
151
extensions_built_in/concept_replacer/ConceptReplacer.py
Normal file
151
extensions_built_in/concept_replacer/ConceptReplacer.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from torch.utils.data import DataLoader
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class ConceptReplacementConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.concept: str = kwargs.get('concept', '')
|
||||
self.replacement: str = kwargs.get('replacement', '')
|
||||
|
||||
|
||||
class ConceptReplacer(BaseSDTrainProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
replacement_list = self.config.get('replacements', [])
|
||||
self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
self.sd.vae.eval()
|
||||
self.sd.vae.to(self.device_torch)
|
||||
|
||||
# textual inversion
|
||||
if self.embedding is not None:
|
||||
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
||||
self.sd.text_encoder.train()
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
|
||||
batch_replacement_list = []
|
||||
# get a random replacement for each prompt
|
||||
for prompt in conditioned_prompts:
|
||||
replacement = random.choice(self.replacement_list)
|
||||
batch_replacement_list.append(replacement)
|
||||
|
||||
# build out prompts
|
||||
concept_prompts = []
|
||||
replacement_prompts = []
|
||||
for idx, replacement in enumerate(batch_replacement_list):
|
||||
prompt = conditioned_prompts[idx]
|
||||
|
||||
# insert shuffled concept at beginning and end of prompt
|
||||
shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
|
||||
random.shuffle(shuffled_concept)
|
||||
shuffled_concept = ', '.join(shuffled_concept)
|
||||
concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
|
||||
|
||||
# insert replacement at beginning and end of prompt
|
||||
shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
|
||||
random.shuffle(shuffled_replacement)
|
||||
shuffled_replacement = ', '.join(shuffled_replacement)
|
||||
replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
|
||||
|
||||
# predict the replacement without network
|
||||
conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
|
||||
|
||||
replacement_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
del conditional_embeds
|
||||
replacement_pred = replacement_pred.detach()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
|
||||
# activate network if it exits
|
||||
with network:
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
# embed the prompts
|
||||
conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
|
||||
if not grad_on_text_encoder:
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
self.optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.embedding.restore_embeddings()
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss.item()}
|
||||
)
|
||||
# reset network multiplier
|
||||
network.multiplier = 1.0
|
||||
|
||||
return loss_dict
|
||||
26
extensions_built_in/concept_replacer/__init__.py
Normal file
26
extensions_built_in/concept_replacer/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# This is an example extension for custom training. It is great for experimenting with new ideas.
|
||||
from toolkit.extension import Extension
|
||||
|
||||
|
||||
# This is for generic training (LoRA, Dreambooth, FineTuning)
|
||||
class ConceptReplacerExtension(Extension):
|
||||
# uid must be unique, it is how the extension is identified
|
||||
uid = "concept_replacer"
|
||||
|
||||
# name is the name of the extension for printing
|
||||
name = "Concept Replacer"
|
||||
|
||||
# This is where your process class is loaded
|
||||
# keep your imports in here so they don't slow down the rest of the program
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .ConceptReplacer import ConceptReplacer
|
||||
return ConceptReplacer
|
||||
|
||||
|
||||
|
||||
AI_TOOLKIT_EXTENSIONS = [
|
||||
# you can put a list of extensions here
|
||||
ConceptReplacerExtension,
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
---
|
||||
job: extension
|
||||
config:
|
||||
name: test_v1
|
||||
process:
|
||||
- type: 'textual_inversion_trainer'
|
||||
training_folder: "out/TI"
|
||||
device: cuda:0
|
||||
# for tensorboard logging
|
||||
log_dir: "out/.tensorboard"
|
||||
embedding:
|
||||
trigger: "your_trigger_here"
|
||||
tokens: 12
|
||||
init_words: "man with short brown hair"
|
||||
save_format: "safetensors" # 'safetensors' or 'pt'
|
||||
save:
|
||||
dtype: float16 # precision to save
|
||||
save_every: 100 # save every this many steps
|
||||
max_step_saves_to_keep: 5 # only affects step counts
|
||||
datasets:
|
||||
- folder_path: "/path/to/dataset"
|
||||
caption_ext: "txt"
|
||||
default_caption: "[trigger]"
|
||||
buckets: true
|
||||
resolution: 512
|
||||
train:
|
||||
noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
|
||||
steps: 3000
|
||||
weight_jitter: 0.0
|
||||
lr: 5e-5
|
||||
train_unet: false
|
||||
gradient_checkpointing: true
|
||||
train_text_encoder: false
|
||||
optimizer: "adamw"
|
||||
# optimizer: "prodigy"
|
||||
optimizer_params:
|
||||
weight_decay: 1e-2
|
||||
lr_scheduler: "constant"
|
||||
max_denoising_steps: 1000
|
||||
batch_size: 4
|
||||
dtype: bf16
|
||||
xformers: true
|
||||
min_snr_gamma: 5.0
|
||||
# skip_first_sample: true
|
||||
noise_offset: 0.0 # not needed for this
|
||||
model:
|
||||
# objective reality v2
|
||||
name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
|
||||
is_v2: false # for v2 models
|
||||
is_xl: false # for SDXL models
|
||||
is_v_pred: false # for v-prediction models (most v2 models)
|
||||
sample:
|
||||
sampler: "ddpm" # must match train.noise_scheduler
|
||||
sample_every: 100 # sample every this many steps
|
||||
width: 512
|
||||
height: 512
|
||||
prompts:
|
||||
- "photo of [trigger] laughing"
|
||||
- "photo of [trigger] smiling"
|
||||
- "[trigger] close up"
|
||||
- "dark scene [trigger] frozen"
|
||||
- "[trigger] nighttime"
|
||||
- "a painting of [trigger]"
|
||||
- "a drawing of [trigger]"
|
||||
- "a cartoon of [trigger]"
|
||||
- "[trigger] pixar style"
|
||||
- "[trigger] costume"
|
||||
neg: ""
|
||||
seed: 42
|
||||
walk_seed: false
|
||||
guidance_scale: 7
|
||||
sample_steps: 20
|
||||
network_multiplier: 1.0
|
||||
|
||||
logging:
|
||||
log_every: 10 # log every this many steps
|
||||
use_wandb: false # not supported yet
|
||||
verbose: false
|
||||
|
||||
# You can put any information you want here, and it will be saved in the model.
|
||||
# The below is an example, but you can put your grocery list in it if you want.
|
||||
# It is saved in the model so be aware of that. The software will include this
|
||||
# plus some other information for you automatically
|
||||
meta:
|
||||
# [name] gets replaced with the name above
|
||||
name: "[name]"
|
||||
# version: '1.0'
|
||||
# creator:
|
||||
# name: Your Name
|
||||
# email: your@gmail.com
|
||||
# website: https://your.website
|
||||
20
extensions_built_in/dataset_tools/DatasetTools.py
Normal file
20
extensions_built_in/dataset_tools/DatasetTools.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from collections import OrderedDict
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseExtensionProcess
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class DatasetTools(BaseExtensionProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
raise NotImplementedError("This extension is not yet implemented")
|
||||
196
extensions_built_in/dataset_tools/SuperTagger.py
Normal file
196
extensions_built_in/dataset_tools/SuperTagger.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import gc
|
||||
import traceback
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from tqdm import tqdm
|
||||
|
||||
from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
|
||||
from .tools.fuyu_utils import FuyuImageProcessor
|
||||
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
|
||||
from .tools.llava_utils import LLaVAImageProcessor
|
||||
from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
|
||||
from jobs.process import BaseExtensionProcess
|
||||
from .tools.sync_tools import get_img_paths
|
||||
|
||||
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
VERSION = 2
|
||||
|
||||
|
||||
class SuperTagger(BaseExtensionProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
parent_dir = config.get('parent_dir', None)
|
||||
self.dataset_paths: list[str] = config.get('dataset_paths', [])
|
||||
self.device = config.get('device', 'cuda')
|
||||
self.steps: list[Step] = config.get('steps', [])
|
||||
self.caption_method = config.get('caption_method', 'llava:default')
|
||||
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
|
||||
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
|
||||
self.force_reprocess_img = config.get('force_reprocess_img', False)
|
||||
self.caption_replacements = config.get('caption_replacements', default_replacements)
|
||||
self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
|
||||
self.master_dataset_dict = OrderedDict()
|
||||
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
|
||||
if parent_dir is not None and len(self.dataset_paths) == 0:
|
||||
# find all folders in the patent_dataset_path
|
||||
self.dataset_paths = [
|
||||
os.path.join(parent_dir, folder)
|
||||
for folder in os.listdir(parent_dir)
|
||||
if os.path.isdir(os.path.join(parent_dir, folder))
|
||||
]
|
||||
else:
|
||||
# make sure they exist
|
||||
for dataset_path in self.dataset_paths:
|
||||
if not os.path.exists(dataset_path):
|
||||
raise ValueError(f"Dataset path does not exist: {dataset_path}")
|
||||
|
||||
print(f"Found {len(self.dataset_paths)} dataset paths")
|
||||
|
||||
self.image_processor: ImageProcessor = self.get_image_processor()
|
||||
|
||||
def get_image_processor(self):
|
||||
if self.caption_method.startswith('llava'):
|
||||
return LLaVAImageProcessor(device=self.device)
|
||||
elif self.caption_method.startswith('fuyu'):
|
||||
return FuyuImageProcessor(device=self.device)
|
||||
else:
|
||||
raise ValueError(f"Unknown caption method: {self.caption_method}")
|
||||
|
||||
def process_image(self, img_path: str):
|
||||
root_img_dir = os.path.dirname(os.path.dirname(img_path))
|
||||
filename = os.path.basename(img_path)
|
||||
filename_no_ext = os.path.splitext(filename)[0]
|
||||
train_dir = os.path.join(root_img_dir, TRAIN_DIR)
|
||||
train_img_path = os.path.join(train_dir, filename)
|
||||
json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
|
||||
|
||||
# check if json exists, if it does load it as image info
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, 'r') as f:
|
||||
img_info = ImgInfo(**json.load(f))
|
||||
else:
|
||||
img_info = ImgInfo()
|
||||
|
||||
# always send steps first in case other processes need them
|
||||
img_info.add_steps(copy.deepcopy(self.steps))
|
||||
img_info.set_version(VERSION)
|
||||
img_info.set_caption_method(self.caption_method)
|
||||
|
||||
image: Image = None
|
||||
caption_image: Image = None
|
||||
|
||||
did_update_image = False
|
||||
|
||||
# trigger reprocess of steps
|
||||
if self.force_reprocess_img:
|
||||
img_info.trigger_image_reprocess()
|
||||
|
||||
# set the image as updated if it does not exist on disk
|
||||
if not os.path.exists(train_img_path):
|
||||
did_update_image = True
|
||||
image = load_image(img_path)
|
||||
if img_info.force_image_process:
|
||||
did_update_image = True
|
||||
image = load_image(img_path)
|
||||
|
||||
# go through the needed steps
|
||||
for step in copy.deepcopy(img_info.state.steps_to_complete):
|
||||
if step == 'caption':
|
||||
# load image
|
||||
if image is None:
|
||||
image = load_image(img_path)
|
||||
if caption_image is None:
|
||||
caption_image = resize_to_max(image, 1024, 1024)
|
||||
|
||||
if not self.image_processor.is_loaded:
|
||||
print('Loading Model. Takes a while, especially the first time')
|
||||
self.image_processor.load_model()
|
||||
|
||||
img_info.caption = self.image_processor.generate_caption(
|
||||
image=caption_image,
|
||||
prompt=self.caption_prompt,
|
||||
replacements=self.caption_replacements
|
||||
)
|
||||
img_info.mark_step_complete(step)
|
||||
elif step == 'caption_short':
|
||||
# load image
|
||||
if image is None:
|
||||
image = load_image(img_path)
|
||||
|
||||
if caption_image is None:
|
||||
caption_image = resize_to_max(image, 1024, 1024)
|
||||
|
||||
if not self.image_processor.is_loaded:
|
||||
print('Loading Model. Takes a while, especially the first time')
|
||||
self.image_processor.load_model()
|
||||
img_info.caption_short = self.image_processor.generate_caption(
|
||||
image=caption_image,
|
||||
prompt=self.caption_short_prompt,
|
||||
replacements=self.caption_short_replacements
|
||||
)
|
||||
img_info.mark_step_complete(step)
|
||||
elif step == 'contrast_stretch':
|
||||
# load image
|
||||
if image is None:
|
||||
image = load_image(img_path)
|
||||
image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
|
||||
did_update_image = True
|
||||
img_info.mark_step_complete(step)
|
||||
else:
|
||||
raise ValueError(f"Unknown step: {step}")
|
||||
|
||||
os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
|
||||
if did_update_image:
|
||||
image.save(train_img_path)
|
||||
|
||||
if img_info.is_dirty:
|
||||
with open(json_path, 'w') as f:
|
||||
json.dump(img_info.to_dict(), f, indent=4)
|
||||
|
||||
if self.dataset_master_config_file:
|
||||
# add to master dict
|
||||
self.master_dataset_dict[train_img_path] = img_info.to_dict()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
imgs_to_process = []
|
||||
# find all images
|
||||
for dataset_path in self.dataset_paths:
|
||||
raw_dir = os.path.join(dataset_path, RAW_DIR)
|
||||
raw_image_paths = get_img_paths(raw_dir)
|
||||
for raw_image_path in raw_image_paths:
|
||||
imgs_to_process.append(raw_image_path)
|
||||
|
||||
if len(imgs_to_process) == 0:
|
||||
print(f"No images to process")
|
||||
else:
|
||||
print(f"Found {len(imgs_to_process)} to process")
|
||||
|
||||
for img_path in tqdm(imgs_to_process, desc="Processing images"):
|
||||
try:
|
||||
self.process_image(img_path)
|
||||
except Exception:
|
||||
# print full stack trace
|
||||
print(traceback.format_exc())
|
||||
continue
|
||||
# self.process_image(img_path)
|
||||
|
||||
if self.dataset_master_config_file is not None:
|
||||
# save it as json
|
||||
with open(self.dataset_master_config_file, 'w') as f:
|
||||
json.dump(self.master_dataset_dict, f, indent=4)
|
||||
|
||||
del self.image_processor
|
||||
flush()
|
||||
131
extensions_built_in/dataset_tools/SyncFromCollection.py
Normal file
131
extensions_built_in/dataset_tools/SyncFromCollection.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import gc
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
|
||||
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
|
||||
get_img_paths
|
||||
from jobs.process import BaseExtensionProcess
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class SyncFromCollection(BaseExtensionProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
|
||||
self.min_width = config.get('min_width', 1024)
|
||||
self.min_height = config.get('min_height', 1024)
|
||||
|
||||
# add our min_width and min_height to each dataset config if they don't exist
|
||||
for dataset_config in config.get('dataset_sync', []):
|
||||
if 'min_width' not in dataset_config:
|
||||
dataset_config['min_width'] = self.min_width
|
||||
if 'min_height' not in dataset_config:
|
||||
dataset_config['min_height'] = self.min_height
|
||||
|
||||
self.dataset_configs: List[DatasetSyncCollectionConfig] = [
|
||||
DatasetSyncCollectionConfig(**dataset_config)
|
||||
for dataset_config in config.get('dataset_sync', [])
|
||||
]
|
||||
print(f"Found {len(self.dataset_configs)} dataset configs")
|
||||
|
||||
def move_new_images(self, root_dir: str):
|
||||
raw_dir = os.path.join(root_dir, RAW_DIR)
|
||||
new_dir = os.path.join(root_dir, NEW_DIR)
|
||||
new_images = get_img_paths(new_dir)
|
||||
|
||||
for img_path in new_images:
|
||||
# move to raw
|
||||
new_path = os.path.join(raw_dir, os.path.basename(img_path))
|
||||
shutil.move(img_path, new_path)
|
||||
|
||||
# remove new dir
|
||||
shutil.rmtree(new_dir)
|
||||
|
||||
def sync_dataset(self, config: DatasetSyncCollectionConfig):
|
||||
if config.host == 'unsplash':
|
||||
get_images = get_unsplash_images
|
||||
elif config.host == 'pexels':
|
||||
get_images = get_pexels_images
|
||||
else:
|
||||
raise ValueError(f"Unknown host: {config.host}")
|
||||
|
||||
results = {
|
||||
'num_downloaded': 0,
|
||||
'num_skipped': 0,
|
||||
'bad': 0,
|
||||
'total': 0,
|
||||
}
|
||||
|
||||
photos = get_images(config)
|
||||
raw_dir = os.path.join(config.directory, RAW_DIR)
|
||||
new_dir = os.path.join(config.directory, NEW_DIR)
|
||||
raw_images = get_local_image_file_names(raw_dir)
|
||||
new_images = get_local_image_file_names(new_dir)
|
||||
|
||||
for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
|
||||
try:
|
||||
if photo.filename not in raw_images and photo.filename not in new_images:
|
||||
download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
|
||||
results['num_downloaded'] += 1
|
||||
else:
|
||||
results['num_skipped'] += 1
|
||||
except Exception as e:
|
||||
print(f" - BAD({photo.id}): {e}")
|
||||
results['bad'] += 1
|
||||
continue
|
||||
results['total'] += 1
|
||||
|
||||
return results
|
||||
|
||||
def print_results(self, results):
|
||||
print(
|
||||
f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
print(f"Syncing {len(self.dataset_configs)} datasets")
|
||||
all_results = None
|
||||
failed_datasets = []
|
||||
for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
|
||||
try:
|
||||
results = self.sync_dataset(dataset_config)
|
||||
if all_results is None:
|
||||
all_results = {**results}
|
||||
else:
|
||||
for key, value in results.items():
|
||||
all_results[key] += value
|
||||
|
||||
self.print_results(results)
|
||||
except Exception as e:
|
||||
print(f" - FAILED: {e}")
|
||||
if 'response' in e.__dict__:
|
||||
error = f"{e.response.status_code}: {e.response.text}"
|
||||
print(f" - {error}")
|
||||
failed_datasets.append({'dataset': dataset_config, 'error': error})
|
||||
else:
|
||||
failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
|
||||
continue
|
||||
|
||||
print("Moving new images to raw")
|
||||
for dataset_config in self.dataset_configs:
|
||||
self.move_new_images(dataset_config.directory)
|
||||
|
||||
print("Done syncing datasets")
|
||||
self.print_results(all_results)
|
||||
|
||||
if len(failed_datasets) > 0:
|
||||
print(f"Failed to sync {len(failed_datasets)} datasets")
|
||||
for failed in failed_datasets:
|
||||
print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
|
||||
print(f" - ERR: {failed['error']}")
|
||||
43
extensions_built_in/dataset_tools/__init__.py
Normal file
43
extensions_built_in/dataset_tools/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from toolkit.extension import Extension
|
||||
|
||||
|
||||
class DatasetToolsExtension(Extension):
|
||||
uid = "dataset_tools"
|
||||
|
||||
# name is the name of the extension for printing
|
||||
name = "Dataset Tools"
|
||||
|
||||
# This is where your process class is loaded
|
||||
# keep your imports in here so they don't slow down the rest of the program
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .DatasetTools import DatasetTools
|
||||
return DatasetTools
|
||||
|
||||
|
||||
class SyncFromCollectionExtension(Extension):
|
||||
uid = "sync_from_collection"
|
||||
name = "Sync from Collection"
|
||||
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .SyncFromCollection import SyncFromCollection
|
||||
return SyncFromCollection
|
||||
|
||||
|
||||
class SuperTaggerExtension(Extension):
|
||||
uid = "super_tagger"
|
||||
name = "Super Tagger"
|
||||
|
||||
@classmethod
|
||||
def get_process(cls):
|
||||
# import your process class here so it is only loaded when needed and return it
|
||||
from .SuperTagger import SuperTagger
|
||||
return SuperTagger
|
||||
|
||||
|
||||
AI_TOOLKIT_EXTENSIONS = [
|
||||
SyncFromCollectionExtension, DatasetToolsExtension, SuperTaggerExtension
|
||||
]
|
||||
53
extensions_built_in/dataset_tools/tools/caption.py
Normal file
53
extensions_built_in/dataset_tools/tools/caption.py
Normal file
@@ -0,0 +1,53 @@
|
||||
|
||||
caption_manipulation_steps = ['caption', 'caption_short']
|
||||
|
||||
default_long_prompt = 'caption this image. describe every single thing in the image in detail. Do not include any unnecessary words in your description for the sake of good grammar. I want many short statements that serve the single purpose of giving the most thorough description if items as possible in the smallest, comma separated way possible. be sure to describe people\'s moods, clothing, the environment, lighting, colors, and everything.'
|
||||
default_short_prompt = 'caption this image in less than ten words'
|
||||
|
||||
default_replacements = [
|
||||
("the image features", ""),
|
||||
("the image shows", ""),
|
||||
("the image depicts", ""),
|
||||
("the image is", ""),
|
||||
("in this image", ""),
|
||||
("in the image", ""),
|
||||
]
|
||||
|
||||
|
||||
def clean_caption(cap, replacements=None):
|
||||
if replacements is None:
|
||||
replacements = default_replacements
|
||||
|
||||
# remove any newlines
|
||||
cap = cap.replace("\n", ", ")
|
||||
cap = cap.replace("\r", ", ")
|
||||
cap = cap.replace(".", ",")
|
||||
cap = cap.replace("\"", "")
|
||||
|
||||
# remove unicode characters
|
||||
cap = cap.encode('ascii', 'ignore').decode('ascii')
|
||||
|
||||
# make lowercase
|
||||
cap = cap.lower()
|
||||
# remove any extra spaces
|
||||
cap = " ".join(cap.split())
|
||||
|
||||
for replacement in replacements:
|
||||
if replacement[0].startswith('*'):
|
||||
# we are removing all text if it starts with this and the rest matches
|
||||
search_text = replacement[0][1:]
|
||||
if cap.startswith(search_text):
|
||||
cap = ""
|
||||
else:
|
||||
cap = cap.replace(replacement[0].lower(), replacement[1].lower())
|
||||
|
||||
cap_list = cap.split(",")
|
||||
# trim whitespace
|
||||
cap_list = [c.strip() for c in cap_list]
|
||||
# remove empty strings
|
||||
cap_list = [c for c in cap_list if c != ""]
|
||||
# remove duplicates
|
||||
cap_list = list(dict.fromkeys(cap_list))
|
||||
# join back together
|
||||
cap = ", ".join(cap_list)
|
||||
return cap
|
||||
@@ -0,0 +1,187 @@
|
||||
import json
|
||||
from typing import Literal, Type, TYPE_CHECKING
|
||||
|
||||
Host: Type = Literal['unsplash', 'pexels']
|
||||
|
||||
RAW_DIR = "raw"
|
||||
NEW_DIR = "_tmp"
|
||||
TRAIN_DIR = "train"
|
||||
DEPTH_DIR = "depth"
|
||||
|
||||
from .image_tools import Step, img_manipulation_steps
|
||||
from .caption import caption_manipulation_steps
|
||||
|
||||
|
||||
class DatasetSyncCollectionConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.host: Host = kwargs.get('host', None)
|
||||
self.collection_id: str = kwargs.get('collection_id', None)
|
||||
self.directory: str = kwargs.get('directory', None)
|
||||
self.api_key: str = kwargs.get('api_key', None)
|
||||
self.min_width: int = kwargs.get('min_width', 1024)
|
||||
self.min_height: int = kwargs.get('min_height', 1024)
|
||||
|
||||
if self.host is None:
|
||||
raise ValueError("host is required")
|
||||
if self.collection_id is None:
|
||||
raise ValueError("collection_id is required")
|
||||
if self.directory is None:
|
||||
raise ValueError("directory is required")
|
||||
if self.api_key is None:
|
||||
raise ValueError(f"api_key is required: {self.host}:{self.collection_id}")
|
||||
|
||||
|
||||
class ImageState:
|
||||
def __init__(self, **kwargs):
|
||||
self.steps_complete: list[Step] = kwargs.get('steps_complete', [])
|
||||
self.steps_to_complete: list[Step] = kwargs.get('steps_to_complete', [])
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'steps_complete': self.steps_complete
|
||||
}
|
||||
|
||||
|
||||
class Rect:
|
||||
def __init__(self, **kwargs):
|
||||
self.x = kwargs.get('x', 0)
|
||||
self.y = kwargs.get('y', 0)
|
||||
self.width = kwargs.get('width', 0)
|
||||
self.height = kwargs.get('height', 0)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'x': self.x,
|
||||
'y': self.y,
|
||||
'width': self.width,
|
||||
'height': self.height
|
||||
}
|
||||
|
||||
|
||||
class ImgInfo:
|
||||
def __init__(self, **kwargs):
|
||||
self.version: int = kwargs.get('version', None)
|
||||
self.caption: str = kwargs.get('caption', None)
|
||||
self.caption_short: str = kwargs.get('caption_short', None)
|
||||
self.poi = [Rect(**poi) for poi in kwargs.get('poi', [])]
|
||||
self.state = ImageState(**kwargs.get('state', {}))
|
||||
self.caption_method = kwargs.get('caption_method', None)
|
||||
self.other_captions = kwargs.get('other_captions', {})
|
||||
self._upgrade_state()
|
||||
self.force_image_process: bool = False
|
||||
self._requested_steps: list[Step] = []
|
||||
|
||||
self.is_dirty: bool = False
|
||||
|
||||
def _upgrade_state(self):
|
||||
# upgrades older states
|
||||
if self.caption is not None and 'caption' not in self.state.steps_complete:
|
||||
self.mark_step_complete('caption')
|
||||
self.is_dirty = True
|
||||
if self.caption_short is not None and 'caption_short' not in self.state.steps_complete:
|
||||
self.mark_step_complete('caption_short')
|
||||
self.is_dirty = True
|
||||
if self.caption_method is None and self.caption is not None:
|
||||
# added caption method in version 2. Was all llava before that
|
||||
self.caption_method = 'llava:default'
|
||||
self.is_dirty = True
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'version': self.version,
|
||||
'caption_method': self.caption_method,
|
||||
'caption': self.caption,
|
||||
'caption_short': self.caption_short,
|
||||
'poi': [poi.to_dict() for poi in self.poi],
|
||||
'state': self.state.to_dict(),
|
||||
'other_captions': self.other_captions
|
||||
}
|
||||
|
||||
def mark_step_complete(self, step: Step):
|
||||
if step not in self.state.steps_complete:
|
||||
self.state.steps_complete.append(step)
|
||||
if step in self.state.steps_to_complete:
|
||||
self.state.steps_to_complete.remove(step)
|
||||
self.is_dirty = True
|
||||
|
||||
def add_step(self, step: Step):
|
||||
if step not in self.state.steps_to_complete and step not in self.state.steps_complete:
|
||||
self.state.steps_to_complete.append(step)
|
||||
|
||||
def trigger_image_reprocess(self):
|
||||
if self._requested_steps is None:
|
||||
raise Exception("Must call add_steps before trigger_image_reprocess")
|
||||
steps = self._requested_steps
|
||||
# remove all image manipulationf from steps_to_complete
|
||||
for step in img_manipulation_steps:
|
||||
if step in self.state.steps_to_complete:
|
||||
self.state.steps_to_complete.remove(step)
|
||||
if step in self.state.steps_complete:
|
||||
self.state.steps_complete.remove(step)
|
||||
self.force_image_process = True
|
||||
self.is_dirty = True
|
||||
# we want to keep the order passed in process file
|
||||
for step in steps:
|
||||
if step in img_manipulation_steps:
|
||||
self.add_step(step)
|
||||
|
||||
def add_steps(self, steps: list[Step]):
|
||||
self._requested_steps = [step for step in steps]
|
||||
for stage in steps:
|
||||
self.add_step(stage)
|
||||
|
||||
# update steps if we have any img processes not complete, we have to reprocess them all
|
||||
# if any steps_to_complete are in img_manipulation_steps
|
||||
|
||||
is_manipulating_image = any([step in img_manipulation_steps for step in self.state.steps_to_complete])
|
||||
order_has_changed = False
|
||||
|
||||
if not is_manipulating_image:
|
||||
# check to see if order has changed. No need to if already redoing it. Will detect if ones are removed
|
||||
target_img_manipulation_order = [step for step in steps if step in img_manipulation_steps]
|
||||
current_img_manipulation_order = [step for step in self.state.steps_complete if
|
||||
step in img_manipulation_steps]
|
||||
if target_img_manipulation_order != current_img_manipulation_order:
|
||||
order_has_changed = True
|
||||
|
||||
if is_manipulating_image or order_has_changed:
|
||||
self.trigger_image_reprocess()
|
||||
|
||||
def set_caption_method(self, method: str):
|
||||
if self._requested_steps is None:
|
||||
raise Exception("Must call add_steps before set_caption_method")
|
||||
if self.caption_method != method:
|
||||
self.is_dirty = True
|
||||
# move previous caption method to other_captions
|
||||
if self.caption_method is not None and self.caption is not None or self.caption_short is not None:
|
||||
self.other_captions[self.caption_method] = {
|
||||
'caption': self.caption,
|
||||
'caption_short': self.caption_short,
|
||||
}
|
||||
self.caption_method = method
|
||||
self.caption = None
|
||||
self.caption_short = None
|
||||
# see if we have a caption from the new method
|
||||
if method in self.other_captions:
|
||||
self.caption = self.other_captions[method].get('caption', None)
|
||||
self.caption_short = self.other_captions[method].get('caption_short', None)
|
||||
else:
|
||||
self.trigger_new_caption()
|
||||
|
||||
def trigger_new_caption(self):
|
||||
self.caption = None
|
||||
self.caption_short = None
|
||||
self.is_dirty = True
|
||||
# check to see if we have any steps in the complete list and move them to the to_complete list
|
||||
for step in self.state.steps_complete:
|
||||
if step in caption_manipulation_steps:
|
||||
self.state.steps_complete.remove(step)
|
||||
self.state.steps_to_complete.append(step)
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
def set_version(self, version: int):
|
||||
if self.version != version:
|
||||
self.is_dirty = True
|
||||
self.version = version
|
||||
66
extensions_built_in/dataset_tools/tools/fuyu_utils.py
Normal file
66
extensions_built_in/dataset_tools/tools/fuyu_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from transformers import CLIPImageProcessor, BitsAndBytesConfig, AutoTokenizer
|
||||
|
||||
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FuyuImageProcessor:
|
||||
def __init__(self, device='cuda'):
|
||||
from transformers import FuyuProcessor, FuyuForCausalLM
|
||||
self.device = device
|
||||
self.model: FuyuForCausalLM = None
|
||||
self.processor: FuyuProcessor = None
|
||||
self.dtype = torch.bfloat16
|
||||
self.tokenizer: AutoTokenizer
|
||||
self.is_loaded = False
|
||||
|
||||
def load_model(self):
|
||||
from transformers import FuyuProcessor, FuyuForCausalLM
|
||||
model_path = "adept/fuyu-8b"
|
||||
kwargs = {"device_map": self.device}
|
||||
kwargs['load_in_4bit'] = True
|
||||
kwargs['quantization_config'] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=self.dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type='nf4'
|
||||
)
|
||||
self.processor = FuyuProcessor.from_pretrained(model_path)
|
||||
self.model = FuyuForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
self.is_loaded = True
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
self.model = FuyuForCausalLM.from_pretrained(model_path, torch_dtype=self.dtype, **kwargs)
|
||||
self.processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=self.tokenizer)
|
||||
|
||||
def generate_caption(
|
||||
self, image: Image,
|
||||
prompt: str = default_long_prompt,
|
||||
replacements=default_replacements,
|
||||
max_new_tokens=512
|
||||
):
|
||||
# prepare inputs for the model
|
||||
# text_prompt = f"{prompt}\n"
|
||||
|
||||
# image = image.convert('RGB')
|
||||
model_inputs = self.processor(text=prompt, images=[image])
|
||||
model_inputs = {k: v.to(dtype=self.dtype if torch.is_floating_point(v) else v.dtype, device=self.device) for k, v in
|
||||
model_inputs.items()}
|
||||
|
||||
generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
|
||||
prompt_len = model_inputs["input_ids"].shape[-1]
|
||||
output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
|
||||
output = clean_caption(output, replacements=replacements)
|
||||
return output
|
||||
|
||||
# inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
|
||||
# for k, v in inputs.items():
|
||||
# inputs[k] = v.to(self.device)
|
||||
|
||||
# # autoregressively generate text
|
||||
# generation_output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
# generation_text = self.processor.batch_decode(generation_output[:, -max_new_tokens:], skip_special_tokens=True)
|
||||
# output = generation_text[0]
|
||||
#
|
||||
# return clean_caption(output, replacements=replacements)
|
||||
49
extensions_built_in/dataset_tools/tools/image_tools.py
Normal file
49
extensions_built_in/dataset_tools/tools/image_tools.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Literal, Type, TYPE_CHECKING, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
Step: Type = Literal['caption', 'caption_short', 'create_mask', 'contrast_stretch']
|
||||
|
||||
img_manipulation_steps = ['contrast_stretch']
|
||||
|
||||
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llava_utils import LLaVAImageProcessor
|
||||
from .fuyu_utils import FuyuImageProcessor
|
||||
|
||||
ImageProcessor = Union['LLaVAImageProcessor', 'FuyuImageProcessor']
|
||||
|
||||
|
||||
def pil_to_cv2(image):
|
||||
"""Convert a PIL image to a cv2 image."""
|
||||
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def cv2_to_pil(image):
|
||||
"""Convert a cv2 image to a PIL image."""
|
||||
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def load_image(img_path: str):
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
try:
|
||||
# transpose with exif data
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except Exception as e:
|
||||
pass
|
||||
return image
|
||||
|
||||
|
||||
def resize_to_max(image, max_width=1024, max_height=1024):
|
||||
width, height = image.size
|
||||
if width <= max_width and height <= max_height:
|
||||
return image
|
||||
|
||||
scale = min(max_width / width, max_height / height)
|
||||
width = int(width * scale)
|
||||
height = int(height * scale)
|
||||
|
||||
return image.resize((width, height), Image.LANCZOS)
|
||||
85
extensions_built_in/dataset_tools/tools/llava_utils.py
Normal file
85
extensions_built_in/dataset_tools/tools/llava_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
from .caption import default_long_prompt, default_short_prompt, default_replacements, clean_caption
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
||||
|
||||
img_ext = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
|
||||
class LLaVAImageProcessor:
|
||||
def __init__(self, device='cuda'):
|
||||
try:
|
||||
from llava.model import LlavaLlamaForCausalLM
|
||||
except ImportError:
|
||||
# print("You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
|
||||
print(
|
||||
"You need to manually install llava -> pip install --no-deps git+https://github.com/haotian-liu/LLaVA.git")
|
||||
raise
|
||||
self.device = device
|
||||
self.model: LlavaLlamaForCausalLM = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.image_processor: CLIPImageProcessor = None
|
||||
self.is_loaded = False
|
||||
|
||||
def load_model(self):
|
||||
from llava.model import LlavaLlamaForCausalLM
|
||||
|
||||
model_path = "4bit/llava-v1.5-13b-3GB"
|
||||
# kwargs = {"device_map": "auto"}
|
||||
kwargs = {"device_map": self.device}
|
||||
kwargs['load_in_4bit'] = True
|
||||
kwargs['quantization_config'] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type='nf4'
|
||||
)
|
||||
self.model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
vision_tower = self.model.get_vision_tower()
|
||||
if not vision_tower.is_loaded:
|
||||
vision_tower.load_model()
|
||||
vision_tower.to(device=self.device)
|
||||
self.image_processor = vision_tower.image_processor
|
||||
self.is_loaded = True
|
||||
|
||||
def generate_caption(
|
||||
self, image:
|
||||
Image, prompt: str = default_long_prompt,
|
||||
replacements=default_replacements,
|
||||
max_new_tokens=512
|
||||
):
|
||||
from llava.conversation import conv_templates, SeparatorStyle
|
||||
from llava.utils import disable_torch_init
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
||||
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
|
||||
# question = "how many dogs are in the picture?"
|
||||
disable_torch_init()
|
||||
conv_mode = "llava_v0"
|
||||
conv = conv_templates[conv_mode].copy()
|
||||
roles = conv.roles
|
||||
image_tensor = self.image_processor.preprocess([image], return_tensors='pt')['pixel_values'].half().cuda()
|
||||
|
||||
inp = f"{roles[0]}: {prompt}"
|
||||
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
|
||||
conv.append_message(conv.roles[0], inp)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
raw_prompt = conv.get_prompt()
|
||||
input_ids = tokenizer_image_token(raw_prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
|
||||
return_tensors='pt').unsqueeze(0).cuda()
|
||||
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
||||
keywords = [stop_str]
|
||||
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
|
||||
with torch.inference_mode():
|
||||
output_ids = self.model.generate(
|
||||
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
|
||||
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
|
||||
top_p=0.8
|
||||
)
|
||||
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
||||
conv.messages[-1][-1] = outputs
|
||||
output = outputs.rsplit('</s>', 1)[0]
|
||||
return clean_caption(output, replacements=replacements)
|
||||
279
extensions_built_in/dataset_tools/tools/sync_tools.py
Normal file
279
extensions_built_in/dataset_tools/tools/sync_tools.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import os
|
||||
import requests
|
||||
import tqdm
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
|
||||
def img_root_path(img_id: str):
|
||||
return os.path.dirname(os.path.dirname(img_id))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .dataset_tools_config_modules import DatasetSyncCollectionConfig
|
||||
|
||||
img_exts = ['.jpg', '.jpeg', '.webp', '.png']
|
||||
|
||||
class Photo:
|
||||
def __init__(
|
||||
self,
|
||||
id,
|
||||
host,
|
||||
width,
|
||||
height,
|
||||
url,
|
||||
filename
|
||||
):
|
||||
self.id = str(id)
|
||||
self.host = host
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.url = url
|
||||
self.filename = filename
|
||||
|
||||
|
||||
def get_desired_size(img_width: int, img_height: int, min_width: int, min_height: int):
|
||||
if img_width > img_height:
|
||||
scale = min_height / img_height
|
||||
else:
|
||||
scale = min_width / img_width
|
||||
|
||||
new_width = int(img_width * scale)
|
||||
new_height = int(img_height * scale)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
|
||||
def get_pexels_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
|
||||
all_images = []
|
||||
next_page = f"https://api.pexels.com/v1/collections/{config.collection_id}?page=1&per_page=80&type=photos"
|
||||
|
||||
while True:
|
||||
response = requests.get(next_page, headers={
|
||||
"Authorization": f"{config.api_key}"
|
||||
})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
all_images.extend(data['media'])
|
||||
if 'next_page' in data and data['next_page']:
|
||||
next_page = data['next_page']
|
||||
else:
|
||||
break
|
||||
|
||||
photos = []
|
||||
for image in all_images:
|
||||
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
|
||||
url = f"{image['src']['original']}?auto=compress&cs=tinysrgb&h={new_height}&w={new_width}"
|
||||
filename = os.path.basename(image['src']['original'])
|
||||
|
||||
photos.append(Photo(
|
||||
id=image['id'],
|
||||
host="pexels",
|
||||
width=image['width'],
|
||||
height=image['height'],
|
||||
url=url,
|
||||
filename=filename
|
||||
))
|
||||
|
||||
return photos
|
||||
|
||||
|
||||
def get_unsplash_images(config: 'DatasetSyncCollectionConfig') -> List[Photo]:
|
||||
headers = {
|
||||
# "Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"
|
||||
"Authorization": f"Client-ID {config.api_key}"
|
||||
}
|
||||
# headers['Authorization'] = f"Bearer {token}"
|
||||
|
||||
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page=1&per_page=30"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
res_headers = response.headers
|
||||
# parse the link header to get the next page
|
||||
# 'Link': '<https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=82>; rel="last", <https://api.unsplash.com/collections/mIPWwLdfct8/photos?page=2>; rel="next"'
|
||||
has_next_page = False
|
||||
if 'Link' in res_headers:
|
||||
has_next_page = True
|
||||
link_header = res_headers['Link']
|
||||
link_header = link_header.split(',')
|
||||
link_header = [link.strip() for link in link_header]
|
||||
link_header = [link.split(';') for link in link_header]
|
||||
link_header = [[link[0].strip('<>'), link[1].strip().strip('"')] for link in link_header]
|
||||
link_header = {link[1]: link[0] for link in link_header}
|
||||
|
||||
# get page number from last url
|
||||
last_page = link_header['rel="last']
|
||||
last_page = last_page.split('?')[1]
|
||||
last_page = last_page.split('&')
|
||||
last_page = [param.split('=') for param in last_page]
|
||||
last_page = {param[0]: param[1] for param in last_page}
|
||||
last_page = int(last_page['page'])
|
||||
|
||||
all_images = response.json()
|
||||
|
||||
if has_next_page:
|
||||
# assume we start on page 1, so we don't need to get it again
|
||||
for page in tqdm.tqdm(range(2, last_page + 1)):
|
||||
url = f"https://api.unsplash.com/collections/{config.collection_id}/photos?page={page}&per_page=30"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
all_images.extend(response.json())
|
||||
|
||||
photos = []
|
||||
for image in all_images:
|
||||
new_width, new_height = get_desired_size(image['width'], image['height'], config.min_width, config.min_height)
|
||||
url = f"{image['urls']['raw']}&w={new_width}"
|
||||
filename = f"{image['id']}.jpg"
|
||||
|
||||
photos.append(Photo(
|
||||
id=image['id'],
|
||||
host="unsplash",
|
||||
width=image['width'],
|
||||
height=image['height'],
|
||||
url=url,
|
||||
filename=filename
|
||||
))
|
||||
|
||||
return photos
|
||||
|
||||
|
||||
def get_img_paths(dir_path: str):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
local_files = os.listdir(dir_path)
|
||||
# remove non image files
|
||||
local_files = [file for file in local_files if os.path.splitext(file)[1].lower() in img_exts]
|
||||
# make full path
|
||||
local_files = [os.path.join(dir_path, file) for file in local_files]
|
||||
return local_files
|
||||
|
||||
|
||||
def get_local_image_ids(dir_path: str):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
local_files = get_img_paths(dir_path)
|
||||
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
|
||||
return set([os.path.basename(file).split('.')[0] for file in local_files])
|
||||
|
||||
|
||||
def get_local_image_file_names(dir_path: str):
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
local_files = get_img_paths(dir_path)
|
||||
# assuming local files are named after Unsplash IDs, e.g., 'abc123.jpg'
|
||||
return set([os.path.basename(file) for file in local_files])
|
||||
|
||||
|
||||
def download_image(photo: Photo, dir_path: str, min_width: int = 1024, min_height: int = 1024):
|
||||
img_width = photo.width
|
||||
img_height = photo.height
|
||||
|
||||
if img_width < min_width or img_height < min_height:
|
||||
raise ValueError(f"Skipping {photo.id} because it is too small: {img_width}x{img_height}")
|
||||
|
||||
img_response = requests.get(photo.url)
|
||||
img_response.raise_for_status()
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
filename = os.path.join(dir_path, photo.filename)
|
||||
with open(filename, 'wb') as file:
|
||||
file.write(img_response.content)
|
||||
|
||||
|
||||
def update_caption(img_path: str):
|
||||
# if the caption is a txt file, convert it to a json file
|
||||
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||
# see if it exists
|
||||
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json")):
|
||||
# todo add poi and what not
|
||||
return # we have a json file
|
||||
caption = ""
|
||||
# see if txt file exists
|
||||
if os.path.exists(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt")):
|
||||
# read it
|
||||
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"), 'r') as file:
|
||||
caption = file.read()
|
||||
# write json file
|
||||
with open(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.json"), 'w') as file:
|
||||
file.write(f'{{"caption": "{caption}"}}')
|
||||
|
||||
# delete txt file
|
||||
os.remove(os.path.join(os.path.dirname(img_path), f"{filename_no_ext}.txt"))
|
||||
|
||||
|
||||
# def equalize_img(img_path: str):
|
||||
# input_path = img_path
|
||||
# output_path = os.path.join(img_root_path(img_path), COLOR_CORRECTED_DIR, os.path.basename(img_path))
|
||||
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
# process_img(
|
||||
# img_path=input_path,
|
||||
# output_path=output_path,
|
||||
# equalize=True,
|
||||
# max_size=2056,
|
||||
# white_balance=False,
|
||||
# gamma_correction=False,
|
||||
# strength=0.6,
|
||||
# )
|
||||
|
||||
|
||||
# def annotate_depth(img_path: str):
|
||||
# # make fake args
|
||||
# args = argparse.Namespace()
|
||||
# args.annotator = "midas"
|
||||
# args.res = 1024
|
||||
#
|
||||
# img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
#
|
||||
# output = annotate(img, args)
|
||||
#
|
||||
# output = output.astype('uint8')
|
||||
# output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
#
|
||||
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
||||
# output_path = os.path.join(img_root_path(img_path), DEPTH_DIR, os.path.basename(img_path))
|
||||
#
|
||||
# cv2.imwrite(output_path, output)
|
||||
|
||||
|
||||
# def invert_depth(img_path: str):
|
||||
# img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
# # invert the colors
|
||||
# img = cv2.bitwise_not(img)
|
||||
#
|
||||
# os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
||||
# output_path = os.path.join(img_root_path(img_path), INVERTED_DEPTH_DIR, os.path.basename(img_path))
|
||||
# cv2.imwrite(output_path, img)
|
||||
|
||||
|
||||
#
|
||||
# # update our list of raw images
|
||||
# raw_images = get_img_paths(raw_dir)
|
||||
#
|
||||
# # update raw captions
|
||||
# for image_id in tqdm.tqdm(raw_images, desc="Updating raw captions"):
|
||||
# update_caption(image_id)
|
||||
#
|
||||
# # equalize images
|
||||
# for img_path in tqdm.tqdm(raw_images, desc="Equalizing images"):
|
||||
# if img_path not in eq_images:
|
||||
# equalize_img(img_path)
|
||||
#
|
||||
# # update our list of eq images
|
||||
# eq_images = get_img_paths(eq_dir)
|
||||
# # update eq captions
|
||||
# for image_id in tqdm.tqdm(eq_images, desc="Updating eq captions"):
|
||||
# update_caption(image_id)
|
||||
#
|
||||
# # annotate depth
|
||||
# depth_dir = os.path.join(root_dir, DEPTH_DIR)
|
||||
# depth_images = get_img_paths(depth_dir)
|
||||
# for img_path in tqdm.tqdm(eq_images, desc="Annotating depth"):
|
||||
# if img_path not in depth_images:
|
||||
# annotate_depth(img_path)
|
||||
#
|
||||
# depth_images = get_img_paths(depth_dir)
|
||||
#
|
||||
# # invert depth
|
||||
# inv_depth_dir = os.path.join(root_dir, INVERTED_DEPTH_DIR)
|
||||
# inv_depth_images = get_img_paths(inv_depth_dir)
|
||||
# for img_path in tqdm.tqdm(depth_images, desc="Inverting depth"):
|
||||
# if img_path not in inv_depth_images:
|
||||
# invert_depth(img_path)
|
||||
@@ -16,6 +16,7 @@ from toolkit import train_tools
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
import random
|
||||
from toolkit.basic import value_map
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -91,11 +92,18 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
network_neg_weight = network_neg_weight.item()
|
||||
|
||||
# get an array of random floats between -weight_jitter and weight_jitter
|
||||
loss_jitter_multiplier = 1.0
|
||||
weight_jitter = self.slider_config.weight_jitter
|
||||
if weight_jitter > 0.0:
|
||||
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||
orig_network_pos_weight = network_pos_weight
|
||||
network_pos_weight += jitter_list
|
||||
network_neg_weight += (jitter_list * -1.0)
|
||||
# penalize the loss for its distance from network_pos_weight
|
||||
# a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter
|
||||
# so the loss_jitter_multiplier needs to be 0.4
|
||||
loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)
|
||||
|
||||
|
||||
# if items in network_weight list are tensors, convert them to floats
|
||||
|
||||
@@ -146,38 +154,33 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
|
||||
|
||||
flush()
|
||||
|
||||
loss_float = None
|
||||
loss_mirror_float = None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
noisy_latents.requires_grad = False
|
||||
|
||||
# if training text encoder enable grads, else do context of no grad
|
||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||
# text encoding
|
||||
embedding_list = []
|
||||
# embed the prompts
|
||||
# fix issue with them being tuples sometimes
|
||||
prompt_list = []
|
||||
for prompt in prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
if isinstance(prompt, tuple):
|
||||
prompt = prompt[0]
|
||||
prompt_list.append(prompt)
|
||||
conditional_embeds = self.sd.encode_prompt(prompt_list).to(self.device_torch, dtype=dtype)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
|
||||
if self.model_config.is_xl:
|
||||
# todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
|
||||
network_multiplier_list = network_multiplier
|
||||
noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
|
||||
noise_list = torch.chunk(noise, 2, dim=0)
|
||||
timesteps_list = torch.chunk(timesteps, 2, dim=0)
|
||||
conditional_embeds_list = split_prompt_embeds(conditional_embeds)
|
||||
else:
|
||||
network_multiplier_list = [network_multiplier]
|
||||
noisy_latent_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditional_embeds_list = [conditional_embeds]
|
||||
# if self.model_config.is_xl:
|
||||
# # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
|
||||
# network_multiplier_list = network_multiplier
|
||||
# noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
|
||||
# noise_list = torch.chunk(noise, 2, dim=0)
|
||||
# timesteps_list = torch.chunk(timesteps, 2, dim=0)
|
||||
# conditional_embeds_list = split_prompt_embeds(conditional_embeds)
|
||||
# else:
|
||||
network_multiplier_list = [network_multiplier]
|
||||
noisy_latent_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditional_embeds_list = [conditional_embeds]
|
||||
|
||||
losses = []
|
||||
# allow to chunk it out to save vram
|
||||
@@ -205,20 +208,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
# todo add snr gamma here
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss_slide_float = loss.item()
|
||||
loss = loss.mean() * loss_jitter_multiplier
|
||||
|
||||
loss_float = loss.item()
|
||||
losses.append(loss_float)
|
||||
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
from collections import OrderedDict
|
||||
from torch.utils.data import DataLoader
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from typing import Union, Literal, List
|
||||
from diffusers import T2IAdapter
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
|
||||
apply_learnable_snr_gos, LearnableSNRGamma
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -13,101 +23,618 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.do_prior_prediction = False
|
||||
self.do_long_prompts = False
|
||||
if self.train_config.inverted_mask_prior:
|
||||
self.do_prior_prediction = True
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
|
||||
def before_dataset_load(self):
|
||||
self.assistant_adapter = None
|
||||
# get adapter assistant if one is set
|
||||
if self.train_config.adapter_assist_name_or_path is not None:
|
||||
adapter_path = self.train_config.adapter_assist_name_or_path
|
||||
|
||||
# dont name this adapter since we are not training it
|
||||
self.assistant_adapter = T2IAdapter.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||
).to(self.device_torch)
|
||||
self.assistant_adapter.eval()
|
||||
self.assistant_adapter.requires_grad_(False)
|
||||
flush()
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
self.sd.vae.eval()
|
||||
self.sd.vae.to(self.device_torch)
|
||||
|
||||
# textual inversion
|
||||
if self.embedding is not None:
|
||||
# keep original embeddings as reference
|
||||
self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
|
||||
# set text encoder to train. Not sure if this is necessary but diffusers example did it
|
||||
self.sd.text_encoder.train()
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
# move vae to device if we did not cache latents
|
||||
if not self.is_latents_cached:
|
||||
self.sd.vae.eval()
|
||||
self.sd.vae.to(self.device_torch)
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
# offload it. Already cached
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
|
||||
|
||||
# activate network if it exits
|
||||
with network:
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
embedding_list = []
|
||||
# embed the prompts
|
||||
for prompt in conditioned_prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
# you can expand these in a child class to make customization easier
|
||||
def calculate_loss(
|
||||
self,
|
||||
noise_pred: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
noisy_latents: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
mask_multiplier: Union[torch.Tensor, float] = 1.0,
|
||||
prior_pred: Union[torch.Tensor, None] = None,
|
||||
**kwargs
|
||||
):
|
||||
loss_target = self.train_config.loss_target
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype)
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
if self.train_config.match_noise_norm:
|
||||
# match the norm of the noise
|
||||
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||
|
||||
if self.train_config.inverted_mask_prior:
|
||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||
prior_mask_multiplier = 1.0 - mask_multiplier
|
||||
# target_mask_multiplier = mask_multiplier
|
||||
# mask_multiplier = 1.0
|
||||
target = noise
|
||||
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
||||
# set masked multiplier to 1.0 so we dont double apply it
|
||||
# mask_multiplier = 1.0
|
||||
elif prior_pred is not None:
|
||||
# matching adapter prediction
|
||||
target = prior_pred
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
pred = noise_pred
|
||||
|
||||
ignore_snr = False
|
||||
|
||||
if loss_target == 'source' or loss_target == 'unaugmented':
|
||||
# ignore_snr = True
|
||||
if batch.sigmas is None:
|
||||
raise ValueError("Batch sigmas is None. This should not happen")
|
||||
|
||||
# src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
|
||||
denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
|
||||
weighing = batch.sigmas ** -2.0
|
||||
if loss_target == 'source':
|
||||
# denoise the latent and compare to the latent in the batch
|
||||
target = batch.latents
|
||||
elif loss_target == 'unaugmented':
|
||||
# we have to encode images into latents for now
|
||||
# we also denoise as the unaugmented tensor is not a noisy diffirental
|
||||
with torch.no_grad():
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
|
||||
target = unaugmented_latents.detach()
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = target # we are computing loss against denoise latents
|
||||
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
|
||||
# mse loss without reduction
|
||||
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
|
||||
loss = loss_per_element
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
if self.train_config.inverted_mask_prior:
|
||||
# to a loss to unmasked areas of the prior for unmasked regularization
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
prior_pred.float(),
|
||||
pred.float(),
|
||||
reduction="none"
|
||||
)
|
||||
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
|
||||
loss = loss + prior_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
if self.train_config.learnable_snr_gos:
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
|
||||
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
|
||||
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
# back propagate loss to free ram
|
||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
return batch
|
||||
|
||||
def get_guided_loss(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Perform targeted guidance (working title)
|
||||
conditional_noisy_latents = noisy_latents # target images
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
if batch.unconditional_latents is not None:
|
||||
# unconditional latents are the "neutral" images. Add noise here identical to
|
||||
# the noise added to the conditional latents, at the same timesteps
|
||||
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
|
||||
batch.unconditional_latents, noise, timesteps
|
||||
)
|
||||
|
||||
# calculate the differential between our conditional (target image) and out unconditional (neutral image)
|
||||
target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
|
||||
target_differential_noise = target_differential_noise.detach()
|
||||
|
||||
# add the target differential to the target latents as if it were noise with the scheduler, scaled to
|
||||
# the current timestep. Scaling the noise here is important as it scales our guidance to the current
|
||||
# timestep. This is the key to making the guidance work.
|
||||
guidance_latents = self.sd.noise_scheduler.add_noise(
|
||||
conditional_noisy_latents,
|
||||
target_differential_noise,
|
||||
timesteps
|
||||
)
|
||||
|
||||
# Disable the LoRA network so we can predict parent network knowledge without it
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
|
||||
# Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
|
||||
# This acts as our control to preserve the unaltered parts of the image.
|
||||
baseline_prediction = self.sd.predict_noise(
|
||||
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
).detach()
|
||||
|
||||
# turn the LoRA network back on.
|
||||
self.sd.unet.train()
|
||||
self.network.is_active = True
|
||||
self.network.multiplier = network_weight_list
|
||||
|
||||
# do our prediction with LoRA active on the scaled guidance latents
|
||||
prediction = self.sd.predict_noise(
|
||||
latents=guidance_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
|
||||
# remove the baseline prediction from our prediction to get the differential between the two
|
||||
# all that should be left is the differential between the conditional and unconditional images
|
||||
pred_differential_noise = prediction - baseline_prediction
|
||||
|
||||
# for loss, we target ONLY the unscaled differential between our conditional and unconditional latents
|
||||
# not the timestep scaled noise that was added. This is the diffusion training process.
|
||||
# This will guide the network to make identical predictions it previously did for everything EXCEPT our
|
||||
# differential between the conditional and unconditional images (target)
|
||||
loss = torch.nn.functional.mse_loss(
|
||||
pred_differential_noise.float(),
|
||||
target_differential_noise.float(),
|
||||
reduction="none"
|
||||
)
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = self.apply_snr(loss, timesteps)
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
# detach it so parent class can run backward on no grads without throwing error
|
||||
loss = loss.detach()
|
||||
loss.requires_grad_(True)
|
||||
|
||||
return loss
|
||||
|
||||
def get_prior_prediction(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
match_adapter_assist: bool,
|
||||
network_weight_list: list,
|
||||
timesteps: torch.Tensor,
|
||||
pred_kwargs: dict,
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# dont use network on this
|
||||
# self.network.multiplier = 0.0
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
self.sd.unet.train()
|
||||
prior_pred = prior_pred.detach()
|
||||
# remove the residuals as we wont use them on prediction when matching control
|
||||
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
# restore network
|
||||
# self.network.multiplier = network_weight_list
|
||||
self.network.is_active = was_network_active
|
||||
return prior_pred
|
||||
|
||||
def before_unet_predict(self):
|
||||
pass
|
||||
|
||||
def after_unet_predict(self):
|
||||
pass
|
||||
|
||||
def end_of_training_loop(self):
|
||||
pass
|
||||
|
||||
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
|
||||
# check if we are matching the adapter assistant
|
||||
if self.assistant_adapter:
|
||||
if self.train_config.match_adapter_chance == 1.0:
|
||||
match_adapter_assist = True
|
||||
elif self.train_config.match_adapter_chance > 0.0:
|
||||
match_adapter_assist = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
) < self.train_config.match_adapter_chance
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
with torch.no_grad():
|
||||
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
for idx, file_item in enumerate(batch.file_items):
|
||||
if file_item.is_reg:
|
||||
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
|
||||
|
||||
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if has_adapter_img and (self.adapter or self.assistant_adapter):
|
||||
with self.timer('get_adapter_images'):
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
# match in channels
|
||||
if self.assistant_adapter is not None:
|
||||
in_channels = self.assistant_adapter.config.in_channels
|
||||
if adapter_images.shape[1] != in_channels:
|
||||
# we need to match the channels
|
||||
adapter_images = adapter_images[:, :in_channels, :, :]
|
||||
else:
|
||||
raise NotImplementedError("Adapter images now must be loaded with dataloader")
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
if batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
def get_adapter_multiplier():
|
||||
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
||||
# training a t2i adapter, not using as assistant.
|
||||
return 1.0
|
||||
elif match_adapter_assist:
|
||||
# training a texture. We want it high
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.0
|
||||
else:
|
||||
# training with assistance, we want it low
|
||||
adapter_strength_min = 0.4
|
||||
adapter_strength_max = 0.7
|
||||
# adapter_strength_min = 0.9
|
||||
# adapter_strength_max = 1.1
|
||||
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
)
|
||||
|
||||
adapter_conditioning_scale = value_map(
|
||||
adapter_conditioning_scale,
|
||||
0.0,
|
||||
1.0,
|
||||
adapter_strength_min,
|
||||
adapter_strength_max
|
||||
)
|
||||
return adapter_conditioning_scale
|
||||
|
||||
# flush()
|
||||
with self.timer('grad_setup'):
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# activate network if it exits
|
||||
|
||||
prompts_1 = conditioned_prompts
|
||||
prompts_2 = None
|
||||
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
|
||||
prompts_1 = batch.get_caption_short_list()
|
||||
prompts_2 = conditioned_prompts
|
||||
|
||||
# make the batch splits
|
||||
if self.train_config.single_item_batching:
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
raise ValueError("Single item batching is not supported when training the refiner")
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# chunk/split everything
|
||||
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
||||
noise_list = torch.chunk(noise, batch_size, dim=0)
|
||||
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
|
||||
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
|
||||
if imgs is not None:
|
||||
imgs_list = torch.chunk(imgs, batch_size, dim=0)
|
||||
else:
|
||||
imgs_list = [None for _ in range(batch_size)]
|
||||
if adapter_images is not None:
|
||||
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
|
||||
else:
|
||||
adapter_images_list = [None for _ in range(batch_size)]
|
||||
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
|
||||
if prompts_2 is None:
|
||||
prompt_2_list = [None for _ in range(batch_size)]
|
||||
else:
|
||||
prompt_2_list = [[prompt] for prompt in prompts_2]
|
||||
|
||||
else:
|
||||
noisy_latents_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditioned_prompts_list = [prompts_1]
|
||||
imgs_list = [imgs]
|
||||
adapter_images_list = [adapter_images]
|
||||
mask_multiplier_list = [mask_multiplier]
|
||||
if prompts_2 is None:
|
||||
prompt_2_list = [None]
|
||||
else:
|
||||
prompt_2_list = [prompts_2]
|
||||
|
||||
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
|
||||
noisy_latents_list,
|
||||
noise_list,
|
||||
timesteps_list,
|
||||
conditioned_prompts_list,
|
||||
imgs_list,
|
||||
adapter_images_list,
|
||||
mask_multiplier_list,
|
||||
prompt_2_list
|
||||
):
|
||||
if self.train_config.negative_prompt is not None:
|
||||
# add negative prompt
|
||||
conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
|
||||
range(len(conditioned_prompts))]
|
||||
if prompt_2 is not None:
|
||||
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
with network:
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
with torch.set_grad_enabled(False):
|
||||
# make sure it is in eval mode
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(
|
||||
conditioned_prompts, prompt_2,
|
||||
dropout_prob=self.train_config.prompt_dropout_prob,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if has_adapter_img and (
|
||||
(self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
adapter = self.adapter if self.adapter else self.assistant_adapter
|
||||
adapter_multiplier = get_adapter_multiplier()
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = adapter(adapter_images)
|
||||
if self.assistant_adapter:
|
||||
# not training. detach
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
else:
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
prior_pred = None
|
||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the giadance later
|
||||
if batch.unconditional_latents is not None:
|
||||
# do guided loss
|
||||
loss = self.get_guided_loss(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
batch=batch,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
else:
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
with self.timer('backward'):
|
||||
# todo we have multiplier seperated. works for now as res are not in same batch, but need to change
|
||||
loss = loss * loss_multiplier.mean()
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
# with fsdp_overlap_step_with_backward():
|
||||
loss.backward()
|
||||
# flush()
|
||||
|
||||
if not self.is_grad_accumulation_step:
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# only step if we are not accumulating
|
||||
with self.timer('optimizer_step'):
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# gradient accumulation. Just a place for breakpoint
|
||||
pass
|
||||
|
||||
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
|
||||
index_no_updates[
|
||||
min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
|
||||
with torch.no_grad():
|
||||
self.sd.text_encoder.get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = self.orig_embeds_params[index_no_updates]
|
||||
with self.timer('restore_embeddings'):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.embedding.restore_embeddings()
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss.item()}
|
||||
)
|
||||
|
||||
self.end_of_training_loop()
|
||||
|
||||
return loss_dict
|
||||
|
||||
@@ -19,7 +19,7 @@ config:
|
||||
max_step_saves_to_keep: 5 # only affects step counts
|
||||
datasets:
|
||||
- folder_path: "/path/to/dataset"
|
||||
caption_type: "txt"
|
||||
caption_ext: "txt"
|
||||
default_caption: "[trigger]"
|
||||
buckets: true
|
||||
resolution: 512
|
||||
|
||||
2
info.py
2
info.py
@@ -3,6 +3,6 @@ from collections import OrderedDict
|
||||
v = OrderedDict()
|
||||
v["name"] = "ai-toolkit"
|
||||
v["repo"] = "https://github.com/ostris/ai-toolkit"
|
||||
v["version"] = "0.0.4"
|
||||
v["version"] = "0.1.0"
|
||||
|
||||
software_meta = v
|
||||
|
||||
@@ -2,6 +2,8 @@ import copy
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.timer import Timer
|
||||
|
||||
|
||||
class BaseProcess(object):
|
||||
|
||||
@@ -18,6 +20,9 @@ class BaseProcess(object):
|
||||
self.raw_process_config = config
|
||||
self.name = self.get_conf('name', self.job.name)
|
||||
self.meta = copy.deepcopy(self.job.meta)
|
||||
self.timer: Timer = Timer(f'{self.name} Timer')
|
||||
self.performance_log_every = self.get_conf('performance_log_every', 0)
|
||||
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,10 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
@@ -28,9 +30,18 @@ class BaseTrainProcess(BaseProcess):
|
||||
self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
|
||||
self.progress_bar: 'tqdm' = None
|
||||
|
||||
self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
|
||||
# if training seed is set, use it
|
||||
if self.training_seed is not None:
|
||||
torch.manual_seed(self.training_seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(self.training_seed)
|
||||
random.seed(self.training_seed)
|
||||
|
||||
self.progress_bar = None
|
||||
self.writer = None
|
||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
|
||||
self.training_folder = self.get_conf('training_folder',
|
||||
self.job.training_folder if hasattr(self.job, 'training_folder') else None)
|
||||
self.save_root = os.path.join(self.training_folder, self.name)
|
||||
self.step = 0
|
||||
self.first_step = 0
|
||||
@@ -62,8 +73,7 @@ class BaseTrainProcess(BaseProcess):
|
||||
self.writer = SummaryWriter(summary_dir)
|
||||
|
||||
def save_training_config(self):
|
||||
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
os.makedirs(self.save_root, exist_ok=True)
|
||||
save_dif = os.path.join(self.save_root, f'process_config_{timestamp}.yaml')
|
||||
save_dif = os.path.join(self.save_root, f'config.yaml')
|
||||
with open(save_dif, 'w') as f:
|
||||
yaml.dump(self.raw_process_config, f)
|
||||
yaml.dump(self.job.raw_config, f)
|
||||
|
||||
@@ -98,7 +98,7 @@ class GenerateProcess(BaseProcess):
|
||||
add_prompt_file=self.generate_config.prompt_file
|
||||
))
|
||||
# generate images
|
||||
self.sd.generate_images(prompt_image_configs)
|
||||
self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
|
||||
|
||||
print("Done generating images")
|
||||
# cleanup
|
||||
|
||||
@@ -3,10 +3,12 @@ import glob
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
# from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
from toolkit.basic import flush
|
||||
from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.utils.data import DataLoader, ConcatDataset
|
||||
@@ -67,9 +69,10 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
self.augmentations = self.get_conf('augmentations', {})
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
if self.torch_dtype == torch.bfloat16:
|
||||
self.esrgan_dtype = torch.float16
|
||||
self.esrgan_dtype = torch.float32
|
||||
else:
|
||||
self.esrgan_dtype = torch.float32
|
||||
|
||||
self.vgg_19 = None
|
||||
self.style_weight_scalers = []
|
||||
self.content_weight_scalers = []
|
||||
@@ -232,6 +235,7 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
pattern_size=self.zoom,
|
||||
dtype=self.torch_dtype
|
||||
).to(self.device, dtype=self.torch_dtype)
|
||||
self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
|
||||
loss = torch.mean(self._pattern_loss(pred, target))
|
||||
return loss
|
||||
|
||||
@@ -269,13 +273,63 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
if self.use_critic:
|
||||
self.critic.save(step)
|
||||
|
||||
def sample(self, step=None):
|
||||
def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if not os.path.exists(sample_folder):
|
||||
os.makedirs(sample_folder, exist_ok=True)
|
||||
batch_sample_folder = os.path.join(self.save_root, 'samples_batch')
|
||||
|
||||
batch_targets = None
|
||||
batch_inputs = None
|
||||
if batch is not None and not os.path.exists(batch_sample_folder):
|
||||
os.makedirs(batch_sample_folder, exist_ok=True)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def process_and_save(img, target_img, save_path):
|
||||
img = img.to(self.device, dtype=self.esrgan_dtype)
|
||||
output = self.model(img)
|
||||
# output = (output / 2 + 0.5).clamp(0, 1)
|
||||
output = output.clamp(0, 1)
|
||||
img = img.clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
|
||||
# convert to pillow image
|
||||
output = Image.fromarray((output * 255).astype(np.uint8))
|
||||
img = Image.fromarray((img * 255).astype(np.uint8))
|
||||
|
||||
if isinstance(target_img, torch.Tensor):
|
||||
# convert to pil
|
||||
target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
target_img = Image.fromarray((target_img * 255).astype(np.uint8))
|
||||
|
||||
# upscale to size * self.upscale_sample while maintaining pixels
|
||||
output = output.resize(
|
||||
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
|
||||
resample=Image.NEAREST
|
||||
)
|
||||
img = img.resize(
|
||||
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
|
||||
resample=Image.NEAREST
|
||||
)
|
||||
|
||||
width, height = output.size
|
||||
|
||||
# stack input image and decoded image
|
||||
target_image = target_img.resize((width, height))
|
||||
output = output.resize((width, height))
|
||||
img = img.resize((width, height))
|
||||
|
||||
output_img = Image.new('RGB', (width * 3, height))
|
||||
|
||||
output_img.paste(img, (0, 0))
|
||||
output_img.paste(output, (width, 0))
|
||||
output_img.paste(target_image, (width * 2, 0))
|
||||
|
||||
output_img.save(save_path)
|
||||
|
||||
with torch.no_grad():
|
||||
for i, img_url in enumerate(self.sample_sources):
|
||||
img = exif_transpose(Image.open(img_url))
|
||||
@@ -295,30 +349,6 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
|
||||
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
|
||||
img = img
|
||||
output = self.model(img)
|
||||
# output = (output / 2 + 0.5).clamp(0, 1)
|
||||
output = output.clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
|
||||
# convert to pillow image
|
||||
output = Image.fromarray((output * 255).astype(np.uint8))
|
||||
|
||||
# upscale to size * self.upscale_sample while maintaining pixels
|
||||
output = output.resize(
|
||||
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
|
||||
resample=Image.NEAREST
|
||||
)
|
||||
|
||||
width, height = output.size
|
||||
|
||||
# stack input image and decoded image
|
||||
target_image = target_image.resize((width, height))
|
||||
output = output.resize((width, height))
|
||||
|
||||
output_img = Image.new('RGB', (width * 2, height))
|
||||
output_img.paste(target_image, (0, 0))
|
||||
output_img.paste(output, (width, 0))
|
||||
|
||||
step_num = ''
|
||||
if step is not None:
|
||||
@@ -327,8 +357,24 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
seconds_since_epoch = int(time.time())
|
||||
# zero-pad 2 digits
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||
output_img.save(os.path.join(sample_folder, filename))
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
|
||||
process_and_save(img, target_image, os.path.join(sample_folder, filename))
|
||||
|
||||
if batch is not None:
|
||||
batch_targets = batch[0].detach()
|
||||
batch_inputs = batch[1].detach()
|
||||
batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
|
||||
batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
|
||||
|
||||
for i in range(len(batch_inputs)):
|
||||
if step is not None:
|
||||
# zero-pad 9 digits
|
||||
step_num = f"_{str(step).zfill(9)}"
|
||||
seconds_since_epoch = int(time.time())
|
||||
# zero-pad 2 digits
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
|
||||
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
|
||||
|
||||
self.model.train()
|
||||
|
||||
@@ -376,13 +422,14 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
def run(self):
|
||||
super().run()
|
||||
self.load_datasets()
|
||||
steps_per_step = (self.critic.num_critic_per_gen + 1)
|
||||
|
||||
max_step_epochs = self.max_steps // len(self.data_loader)
|
||||
max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
|
||||
num_epochs = self.epochs
|
||||
if num_epochs is None or num_epochs > max_step_epochs:
|
||||
num_epochs = max_step_epochs
|
||||
|
||||
max_epoch_steps = len(self.data_loader) * num_epochs
|
||||
max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
|
||||
num_steps = self.max_steps
|
||||
if num_steps is None or num_steps > max_epoch_steps:
|
||||
num_steps = max_epoch_steps
|
||||
@@ -445,35 +492,60 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
print("Generating baseline samples")
|
||||
self.sample(step=0)
|
||||
# range start at self.epoch_num go to self.epochs
|
||||
critic_losses = []
|
||||
for epoch in range(self.epoch_num, self.epochs, 1):
|
||||
if self.step_num >= self.max_steps:
|
||||
break
|
||||
flush()
|
||||
for targets, inputs in self.data_loader:
|
||||
if self.step_num >= self.max_steps:
|
||||
break
|
||||
with torch.no_grad():
|
||||
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
|
||||
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
|
||||
is_critic_only_step = False
|
||||
if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
|
||||
is_critic_only_step = True
|
||||
|
||||
pred = self.model(inputs)
|
||||
targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
|
||||
inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
|
||||
|
||||
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
|
||||
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
|
||||
optimizer.zero_grad()
|
||||
# dont do grads here for critic step
|
||||
do_grad = not is_critic_only_step
|
||||
with torch.set_grad_enabled(do_grad):
|
||||
pred = self.model(inputs)
|
||||
|
||||
# Run through VGG19
|
||||
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
|
||||
stacked = torch.cat([pred, targets], dim=0)
|
||||
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||
stacked = stacked.clamp(0, 1)
|
||||
self.vgg_19(stacked)
|
||||
pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
|
||||
targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
|
||||
if torch.isnan(pred).any():
|
||||
raise ValueError('pred has nan values')
|
||||
if torch.isnan(targets).any():
|
||||
raise ValueError('targets has nan values')
|
||||
|
||||
if self.use_critic:
|
||||
# Run through VGG19
|
||||
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
|
||||
stacked = torch.cat([pred, targets], dim=0)
|
||||
# stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||
stacked = stacked.clamp(0, 1)
|
||||
self.vgg_19(stacked)
|
||||
# make sure we dont have nans
|
||||
if torch.isnan(self.vgg19_pool_4.tensor).any():
|
||||
raise ValueError('vgg19_pool_4 has nan values')
|
||||
|
||||
if is_critic_only_step:
|
||||
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
|
||||
critic_losses.append(critic_d_loss)
|
||||
# don't do generator step
|
||||
continue
|
||||
else:
|
||||
critic_d_loss = 0.0
|
||||
# doing a regular step
|
||||
if len(critic_losses) == 0:
|
||||
critic_d_loss = 0
|
||||
else:
|
||||
critic_d_loss = sum(critic_losses) / len(critic_losses)
|
||||
|
||||
style_loss = self.get_style_loss() * self.style_weight
|
||||
content_loss = self.get_content_loss() * self.content_weight
|
||||
|
||||
mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
|
||||
tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
|
||||
pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
|
||||
@@ -483,10 +555,13 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
|
||||
# make sure non nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError('loss is nan')
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
@@ -549,7 +624,7 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
if self.sample_every and self.step_num % self.sample_every == 0:
|
||||
# print above the progress bar
|
||||
self.print(f"Sampling at step {self.step_num}")
|
||||
self.sample(self.step_num)
|
||||
self.sample(self.step_num, batch=[targets, inputs])
|
||||
|
||||
if self.save_every and self.step_num % self.save_every == 0:
|
||||
# print above the progress bar
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
# ref:
|
||||
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
from toolkit.config_modules import SliderConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from leco import train_util, model_util
|
||||
from leco.prompt_util import PromptEmbedsCache
|
||||
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||
|
||||
|
||||
def flush():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class LoRAHack:
|
||||
def __init__(self, **kwargs):
|
||||
self.type = kwargs.get('type', 'suppression')
|
||||
|
||||
|
||||
class TrainLoRAHack(BaseSDTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.hack_config = LoRAHack(**self.get_conf('hack', {}))
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
# we don't need text encoder so move it to cpu
|
||||
self.sd.text_encoder.to("cpu")
|
||||
flush()
|
||||
# end hook_before_train_loop
|
||||
|
||||
if self.hack_config.type == 'suppression':
|
||||
# set all params to self.current_suppression
|
||||
params = self.network.parameters()
|
||||
for param in params:
|
||||
# get random noise for each param
|
||||
noise = torch.randn_like(param) - 0.5
|
||||
# apply noise to param
|
||||
param.data = noise * 0.001
|
||||
|
||||
|
||||
def supress_loop(self):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'sup': 0.0}
|
||||
)
|
||||
# increase noise
|
||||
for param in self.network.parameters():
|
||||
# get random noise for each param
|
||||
noise = torch.randn_like(param) - 0.5
|
||||
# apply noise to param
|
||||
param.data = param.data + noise * 0.001
|
||||
|
||||
|
||||
|
||||
return loss_dict
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
if self.hack_config.type == 'suppression':
|
||||
return self.supress_loop()
|
||||
else:
|
||||
raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
|
||||
# end hook_train_loop
|
||||
@@ -1,9 +1,19 @@
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from torchvision.transforms import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.config_modules import SliderConfig
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
from toolkit.prompt_utils import \
|
||||
@@ -21,6 +31,11 @@ def flush():
|
||||
gc.collect()
|
||||
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
|
||||
class TrainSliderProcess(BaseSDTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
@@ -42,6 +57,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# trim targets
|
||||
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
|
||||
|
||||
# get presets
|
||||
self.eval_slider_device_state = get_train_sd_device_state_preset(
|
||||
self.device_torch,
|
||||
train_unet=False,
|
||||
train_text_encoder=False,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=False,
|
||||
train_adapter=False,
|
||||
train_embedding=False,
|
||||
)
|
||||
|
||||
self.train_slider_device_state = get_train_sd_device_state_preset(
|
||||
self.device_torch,
|
||||
train_unet=self.train_config.train_unet,
|
||||
train_text_encoder=False,
|
||||
cached_latents=self.is_latents_cached,
|
||||
train_lora=True,
|
||||
train_adapter=False,
|
||||
train_embedding=False,
|
||||
)
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
|
||||
@@ -66,6 +102,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# trim list to our max steps
|
||||
|
||||
cache = PromptEmbedsCache()
|
||||
print(f"Building prompt cache")
|
||||
|
||||
# get encoded latents for our prompts
|
||||
with torch.no_grad():
|
||||
@@ -175,31 +212,107 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
self.sd.vae.to(self.device_torch)
|
||||
# end hook_before_train_loop
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
def before_dataset_load(self):
|
||||
if self.slider_config.use_adapter == 'depth':
|
||||
print(f"Loading T2I Adapter for depth")
|
||||
# called before LoRA network is loaded but after model is loaded
|
||||
# attach the adapter here so it is there before we load the network
|
||||
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
|
||||
if self.model_config.is_xl:
|
||||
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
|
||||
|
||||
print(f"Loading T2I Adapter from {adapter_path}")
|
||||
|
||||
# get a random pair
|
||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
||||
]
|
||||
# move to device and dtype
|
||||
prompt_pair.to(self.device_torch, dtype=dtype)
|
||||
# dont name this adapter since we are not training it
|
||||
self.t2i_adapter = T2IAdapter.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||
).to(self.device_torch)
|
||||
self.t2i_adapter.eval()
|
||||
self.t2i_adapter.requires_grad_(False)
|
||||
flush()
|
||||
|
||||
# get a random resolution
|
||||
height, width = self.slider_config.resolutions[
|
||||
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
||||
]
|
||||
if self.train_config.gradient_checkpointing:
|
||||
# may get disabled elsewhere
|
||||
self.sd.unet.enable_gradient_checkpointing()
|
||||
@torch.no_grad()
|
||||
def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
|
||||
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
adapter_folder_path = self.slider_config.adapter_img_dir
|
||||
adapter_images = []
|
||||
# loop through images
|
||||
for file_item in batch.file_items:
|
||||
img_path = file_item.path
|
||||
file_name_no_ext = os.path.basename(img_path).split('.')[0]
|
||||
# find the image
|
||||
for ext in img_ext_list:
|
||||
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
|
||||
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
|
||||
break
|
||||
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
|
||||
adapter_tensors = []
|
||||
# load images with torch transforms
|
||||
for idx, adapter_image in enumerate(adapter_images):
|
||||
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
|
||||
# to the smallest dimension
|
||||
img: Image.Image = Image.open(adapter_image)
|
||||
if img.width > img.height:
|
||||
# scale down so height is the same as batch
|
||||
new_height = height
|
||||
new_width = int(img.width * (height / img.height))
|
||||
else:
|
||||
new_width = width
|
||||
new_height = int(img.height * (width / img.width))
|
||||
|
||||
img = img.resize((new_width, new_height))
|
||||
crop_fn = transforms.CenterCrop((height, width))
|
||||
# crop the center to match batch
|
||||
img = crop_fn(img)
|
||||
img = adapter_transforms(img)
|
||||
adapter_tensors.append(img)
|
||||
|
||||
# stack them
|
||||
adapter_tensors = torch.stack(adapter_tensors).to(
|
||||
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
|
||||
)
|
||||
return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
|
||||
# set to eval mode
|
||||
self.sd.set_device_state(self.eval_slider_device_state)
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# get a random pair
|
||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
||||
]
|
||||
# move to device and dtype
|
||||
prompt_pair.to(self.device_torch, dtype=dtype)
|
||||
|
||||
# get a random resolution
|
||||
height, width = self.slider_config.resolutions[
|
||||
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
||||
]
|
||||
if self.train_config.gradient_checkpointing:
|
||||
# may get disabled elsewhere
|
||||
self.sd.unet.enable_gradient_checkpointing()
|
||||
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
|
||||
loss_function = torch.nn.MSELoss()
|
||||
|
||||
pred_kwargs = {}
|
||||
|
||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||
down_kwargs = copy.deepcopy(pred_kwargs)
|
||||
if 'down_block_additional_residuals' in down_kwargs:
|
||||
dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
|
||||
if dbr_batch_size != dn.shape[0]:
|
||||
amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
|
||||
down_kwargs['down_block_additional_residuals'] = [
|
||||
torch.cat([sample.clone()] * amount_to_add) for sample in
|
||||
down_kwargs['down_block_additional_residuals']
|
||||
]
|
||||
return self.sd.predict_noise(
|
||||
latents=dn,
|
||||
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||
@@ -209,9 +322,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
),
|
||||
timestep=cts,
|
||||
guidance_scale=gs,
|
||||
**down_kwargs
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
self.sd.unet.eval()
|
||||
|
||||
# for a complete slider, the batch size is 4 to begin with now
|
||||
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
||||
from_batch = False
|
||||
@@ -219,9 +336,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# traing from a batch of images, not generating ourselves
|
||||
from_batch = True
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
if self.slider_config.adapter_img_dir is not None:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.0
|
||||
|
||||
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
|
||||
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
|
||||
def rand_strength(sample):
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
)
|
||||
|
||||
adapter_conditioning_scale = value_map(
|
||||
adapter_conditioning_scale,
|
||||
0.0,
|
||||
1.0,
|
||||
adapter_strength_min,
|
||||
adapter_strength_max
|
||||
)
|
||||
return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
|
||||
|
||||
down_block_additional_residuals = self.t2i_adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
rand_strength(sample) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
|
||||
current_timestep = timesteps
|
||||
else:
|
||||
|
||||
@@ -229,14 +369,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# ger a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
1, self.train_config.max_denoising_steps, (1,)
|
||||
).item()
|
||||
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=height,
|
||||
@@ -249,32 +386,60 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||
latents = latents.to(self.device_torch, dtype=dtype)
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
)
|
||||
assert not self.network.is_active
|
||||
self.sd.unet.eval()
|
||||
# pass the multiplier list to the network
|
||||
self.network.multiplier = prompt_pair.multiplier_list
|
||||
denoised_latents = self.sd.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
prompt_pair.positive_target, # unconditional
|
||||
prompt_pair.target_class, # target
|
||||
self.train_config.batch_size,
|
||||
),
|
||||
start_timesteps=0,
|
||||
total_timesteps=timesteps_to,
|
||||
guidance_scale=3,
|
||||
)
|
||||
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
|
||||
# split the latents into out prompt pair chunks
|
||||
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
||||
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
||||
|
||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
# split the latents into out prompt pair chunks
|
||||
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
||||
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
||||
|
||||
# flush() # 4.2GB to 3GB on 512x512
|
||||
mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
has_mask = False
|
||||
if batch and batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
has_mask = True
|
||||
|
||||
if has_mask:
|
||||
unmasked_target = get_noise_pred(
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
prompt_pair.target_class, # positive prompt
|
||||
1,
|
||||
current_timestep,
|
||||
denoised_latents
|
||||
)
|
||||
unmasked_target = unmasked_target.detach()
|
||||
unmasked_target.requires_grad = False
|
||||
else:
|
||||
unmasked_target = None
|
||||
|
||||
# 4.20 GB RAM for 512x512
|
||||
positive_latents = get_noise_pred(
|
||||
@@ -286,7 +451,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
positive_latents = positive_latents.detach()
|
||||
positive_latents.requires_grad = False
|
||||
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
neutral_latents = get_noise_pred(
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
@@ -297,7 +461,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
neutral_latents = neutral_latents.detach()
|
||||
neutral_latents.requires_grad = False
|
||||
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
unconditional_latents = get_noise_pred(
|
||||
prompt_pair.positive_target, # negative prompt
|
||||
@@ -308,13 +471,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
unconditional_latents = unconditional_latents.detach()
|
||||
unconditional_latents.requires_grad = False
|
||||
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
|
||||
|
||||
denoised_latents = denoised_latents.detach()
|
||||
|
||||
flush() # 4.2GB to 3GB on 512x512
|
||||
self.sd.set_device_state(self.train_slider_device_state)
|
||||
self.sd.unet.train()
|
||||
# start accumulating gradients
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# 4.20 GB RAM for 512x512
|
||||
anchor_loss_float = None
|
||||
if len(self.anchor_pairs) > 0:
|
||||
with torch.no_grad():
|
||||
@@ -369,9 +533,34 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
del anchor_target_noise
|
||||
# move anchor back to cpu
|
||||
anchor.to("cpu")
|
||||
flush()
|
||||
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
|
||||
with torch.no_grad():
|
||||
if self.slider_config.low_ram:
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
|
||||
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
|
||||
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||
unconditional_latents_chunks = torch.chunk(
|
||||
unconditional_latents.detach(),
|
||||
self.prompt_chunk_size,
|
||||
dim=0
|
||||
)
|
||||
mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
|
||||
if unmasked_target is not None:
|
||||
unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
|
||||
else:
|
||||
unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
|
||||
else:
|
||||
# run through in one instance
|
||||
prompt_pair_chunks = [prompt_pair.detach()]
|
||||
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
|
||||
positive_latents_chunks = [positive_latents.detach()]
|
||||
neutral_latents_chunks = [neutral_latents.detach()]
|
||||
unconditional_latents_chunks = [unconditional_latents.detach()]
|
||||
mask_multiplier_chunks = [mask_multiplier]
|
||||
unmasked_target_chunks = [unmasked_target]
|
||||
|
||||
# flush()
|
||||
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
|
||||
# 3.28 GB RAM for 512x512
|
||||
with self.network:
|
||||
@@ -381,13 +570,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latent_chunk, \
|
||||
positive_latents_chunk, \
|
||||
neutral_latents_chunk, \
|
||||
unconditional_latents_chunk \
|
||||
unconditional_latents_chunk, \
|
||||
mask_multiplier_chunk, \
|
||||
unmasked_target_chunk \
|
||||
in zip(
|
||||
prompt_pair_chunks,
|
||||
denoised_latent_chunks,
|
||||
positive_latents_chunks,
|
||||
neutral_latents_chunks,
|
||||
unconditional_latents_chunks,
|
||||
mask_multiplier_chunks,
|
||||
unmasked_target_chunks
|
||||
):
|
||||
self.network.multiplier = prompt_pair_chunk.multiplier_list
|
||||
target_latents = get_noise_pred(
|
||||
@@ -421,17 +614,43 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||
|
||||
# do inverted mask to preserve non masked
|
||||
if has_mask and unmasked_target_chunk is not None:
|
||||
loss = loss * mask_multiplier_chunk
|
||||
# match the mask unmasked_target_chunk
|
||||
mask_target_loss = torch.nn.functional.mse_loss(
|
||||
target_latents.float(),
|
||||
unmasked_target_chunk.float(),
|
||||
reduction="none"
|
||||
)
|
||||
mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
|
||||
loss += mask_target_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.learnable_snr_gos:
|
||||
if from_batch:
|
||||
# match batch size
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
else:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add snr_gamma
|
||||
loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
if from_batch:
|
||||
# match batch size
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
else:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
|
||||
self.train_config.min_snr_gamma)
|
||||
|
||||
|
||||
loss = loss.mean() * prompt_pair_chunk.weight
|
||||
|
||||
@@ -440,7 +659,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
del target_latents
|
||||
del offset_neutral
|
||||
del loss
|
||||
flush()
|
||||
# flush()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
@@ -457,7 +676,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
# move back to cpu
|
||||
prompt_pair.to("cpu")
|
||||
flush()
|
||||
# flush()
|
||||
|
||||
# reset network
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
|
||||
from .BaseMergeProcess import BaseMergeProcess
|
||||
from .TrainSliderProcess import TrainSliderProcess
|
||||
from .TrainSliderProcessOld import TrainSliderProcessOld
|
||||
from .TrainLoRAHack import TrainLoRAHack
|
||||
from .TrainSDRescaleProcess import TrainSDRescaleProcess
|
||||
from .ModRescaleLoraProcess import ModRescaleLoraProcess
|
||||
from .GenerateProcess import GenerateProcess
|
||||
|
||||
@@ -154,28 +154,28 @@ class Critic:
|
||||
# train critic here
|
||||
self.model.train()
|
||||
self.model.requires_grad_(True)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
critic_losses = []
|
||||
for i in range(self.num_critic_per_gen):
|
||||
inputs = vgg_output.detach()
|
||||
inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
self.optimizer.zero_grad()
|
||||
inputs = vgg_output.detach()
|
||||
inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
|
||||
stacked_output = self.model(inputs)
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
stacked_output = self.model(inputs).float()
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
|
||||
# Compute gradient penalty
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
# Compute gradient penalty
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
|
||||
# Compute WGAN-GP critic loss
|
||||
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
critic_loss.backward()
|
||||
self.optimizer.zero_grad()
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
critic_losses.append(critic_loss.item())
|
||||
# Compute WGAN-GP critic loss
|
||||
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
critic_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# avg loss
|
||||
loss = np.mean(critic_losses)
|
||||
|
||||
1
repositories/batch_annotator
Submodule
1
repositories/batch_annotator
Submodule
Submodule repositories/batch_annotator added at 420e142f6a
1
repositories/ipadapter
Submodule
1
repositories/ipadapter
Submodule
Submodule repositories/ipadapter added at d8ab37c421
@@ -1,9 +1,9 @@
|
||||
torch
|
||||
torchvision
|
||||
safetensors
|
||||
diffusers
|
||||
transformers
|
||||
lycoris_lora
|
||||
diffusers==0.21.3
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
pyyaml
|
||||
oyaml
|
||||
@@ -15,4 +15,10 @@ accelerate
|
||||
toml
|
||||
albumentations
|
||||
pydantic
|
||||
omegaconf
|
||||
omegaconf
|
||||
k-diffusion
|
||||
open_clip_torch
|
||||
timm
|
||||
prodigyopt
|
||||
controlnet_aux==0.0.7
|
||||
python-dotenv
|
||||
14
run.py
14
run.py
@@ -1,8 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Union, OrderedDict
|
||||
from dotenv import load_dotenv
|
||||
# Load the .env file if it exists
|
||||
load_dotenv()
|
||||
|
||||
sys.path.insert(0, os.getcwd())
|
||||
# must come before ANY torch or fastai imports
|
||||
# import toolkit.cuda_malloc
|
||||
|
||||
# turn off diffusers telemetry until I can figure out how to make it opt-in
|
||||
os.environ['DISABLE_TELEMETRY'] = 'YES'
|
||||
|
||||
# check if we have DEBUG_TOOLKIT in env
|
||||
if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
|
||||
# set torch to trace mode
|
||||
import torch
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
import argparse
|
||||
from toolkit.job import get_job
|
||||
|
||||
|
||||
128
scripts/convert_cog.py
Normal file
128
scripts/convert_cog.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
||||
# [diffusers] -> kohya
|
||||
embedding_mapping = {
|
||||
'text_encoders_0': 'clip_l',
|
||||
'text_encoders_1': 'clip_g'
|
||||
}
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps')
|
||||
sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json')
|
||||
|
||||
# load keymap
|
||||
with open(sdxl_keymap_path, 'r') as f:
|
||||
ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']
|
||||
|
||||
# invert the item / key pairs
|
||||
diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()}
|
||||
|
||||
|
||||
def get_ldm_key(diffuser_key):
|
||||
diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}"
|
||||
diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight')
|
||||
diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight')
|
||||
diffuser_key = diffuser_key.replace('_alpha', '.alpha')
|
||||
diffuser_key = diffuser_key.replace('_processor_to_', '_to_')
|
||||
diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.')
|
||||
if diffuser_key in diffusers_ldm_keymap:
|
||||
return diffusers_ldm_keymap[diffuser_key]
|
||||
else:
|
||||
raise KeyError(f"Key {diffuser_key} not found in keymap")
|
||||
|
||||
|
||||
def convert_cog(lora_path, embedding_path):
|
||||
embedding_state_dict = OrderedDict()
|
||||
lora_state_dict = OrderedDict()
|
||||
|
||||
# # normal dict
|
||||
# normal_dict = OrderedDict()
|
||||
# example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors"
|
||||
# with safe_open(example_path, framework="pt", device='cpu') as f:
|
||||
# keys = list(f.keys())
|
||||
# for key in keys:
|
||||
# normal_dict[key] = f.get_tensor(key)
|
||||
|
||||
with safe_open(embedding_path, framework="pt", device='cpu') as f:
|
||||
keys = list(f.keys())
|
||||
for key in keys:
|
||||
new_key = embedding_mapping[key]
|
||||
embedding_state_dict[new_key] = f.get_tensor(key)
|
||||
|
||||
with safe_open(lora_path, framework="pt", device='cpu') as f:
|
||||
keys = list(f.keys())
|
||||
lora_rank = None
|
||||
|
||||
# get the lora dim first. Check first 3 linear layers just to be safe
|
||||
for key in keys:
|
||||
new_key = get_ldm_key(key)
|
||||
tensor = f.get_tensor(key)
|
||||
num_checked = 0
|
||||
if len(tensor.shape) == 2:
|
||||
this_dim = min(tensor.shape)
|
||||
if lora_rank is None:
|
||||
lora_rank = this_dim
|
||||
elif lora_rank != this_dim:
|
||||
raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
|
||||
else:
|
||||
num_checked += 1
|
||||
if num_checked >= 3:
|
||||
break
|
||||
|
||||
for key in keys:
|
||||
new_key = get_ldm_key(key)
|
||||
tensor = f.get_tensor(key)
|
||||
if new_key.endswith('.lora_down.weight'):
|
||||
alpha_key = new_key.replace('.lora_down.weight', '.alpha')
|
||||
# diffusers does not have alpha, they usa an alpha multiplier of 1 which is a tensor weight of the dims
|
||||
# assume first smallest dim is the lora rank if shape is 2
|
||||
lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank
|
||||
|
||||
lora_state_dict[new_key] = tensor
|
||||
|
||||
return lora_state_dict, embedding_state_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'lora_path',
|
||||
type=str,
|
||||
help='Path to lora file'
|
||||
)
|
||||
parser.add_argument(
|
||||
'embedding_path',
|
||||
type=str,
|
||||
help='Path to embedding file'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--lora_output',
|
||||
type=str,
|
||||
default="lora_output",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embedding_output',
|
||||
type=str,
|
||||
default="embedding_output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path)
|
||||
|
||||
# save them
|
||||
save_file(lora_state_dict, args.lora_output)
|
||||
save_file(embedding_state_dict, args.embedding_output)
|
||||
print(f"Saved lora to {args.lora_output}")
|
||||
print(f"Saved embedding to {args.embedding_output}")
|
||||
57
scripts/make_diffusers_model.py
Normal file
57
scripts/make_diffusers_model.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'input_path',
|
||||
type=str,
|
||||
help='Path to original sdxl model'
|
||||
)
|
||||
parser.add_argument(
|
||||
'output_path',
|
||||
type=str,
|
||||
help='output path'
|
||||
)
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
print(f"Loading model from {args.input_path}")
|
||||
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=args.input_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
|
||||
|
||||
print(f"Loaded model from {args.input_path}")
|
||||
|
||||
diffusers_sd.pipeline.fuse_lora()
|
||||
|
||||
meta = OrderedDict()
|
||||
|
||||
diffusers_sd.save(args.output_path, meta=meta)
|
||||
|
||||
|
||||
print(f"Saved to {args.output_path}")
|
||||
67
scripts/make_lcm_sdxl_model.py
Normal file
67
scripts/make_lcm_sdxl_model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'input_path',
|
||||
type=str,
|
||||
help='Path to original sdxl model'
|
||||
)
|
||||
parser.add_argument(
|
||||
'output_path',
|
||||
type=str,
|
||||
help='output path'
|
||||
)
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
print(f"Loading model from {args.input_path}")
|
||||
|
||||
if args.sdxl:
|
||||
adapter_id = "latent-consistency/lcm-lora-sdxl"
|
||||
if args.refiner:
|
||||
adapter_id = "latent-consistency/lcm-lora-sdxl"
|
||||
elif args.ssd:
|
||||
adapter_id = "latent-consistency/lcm-lora-ssd-1b"
|
||||
else:
|
||||
adapter_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=args.input_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
|
||||
|
||||
print(f"Loaded model from {args.input_path}")
|
||||
|
||||
diffusers_sd.pipeline.load_lora_weights(adapter_id)
|
||||
diffusers_sd.pipeline.fuse_lora()
|
||||
|
||||
meta = OrderedDict()
|
||||
|
||||
diffusers_sd.save(args.output_path, meta=meta)
|
||||
|
||||
|
||||
print(f"Saved to {args.output_path}")
|
||||
@@ -1,547 +0,0 @@
|
||||
import gc
|
||||
import time
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import custom_tools.train_tools as train_tools
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
# perlin_noise,
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
SD_SCRIPTS_ROOT = os.path.join(PROJECT_ROOT, "repositories", "sd-scripts")
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, False)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.no_token_padding:
|
||||
train_dataset_group.disable_token_padding()
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# replace captions with names
|
||||
if args.name_replace is not None:
|
||||
print(f"Replacing captions [name] with '{args.name_replace}'")
|
||||
|
||||
train_dataset_group = train_tools.replace_filewords_in_dataset_group(
|
||||
train_dataset_group, args
|
||||
)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
print(
|
||||
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
||||
)
|
||||
print(
|
||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||
)
|
||||
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||
src_diffusers_model_path = None
|
||||
else:
|
||||
src_stable_diffusion_ckpt = None
|
||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||
|
||||
if args.save_model_as is None:
|
||||
save_stable_diffusion_format = load_stable_diffusion_format
|
||||
use_safetensors = args.use_safetensors
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
text_encoder.requires_grad_(train_text_encoder)
|
||||
if not train_text_encoder:
|
||||
print("Text Encoder is not trained.")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
if train_text_encoder:
|
||||
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
else:
|
||||
trainable_params = unet.parameters()
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
if args.sample_first or args.sample_only:
|
||||
# Do initial sample before starting training
|
||||
train_tools.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer,
|
||||
text_encoder, unet, force_sample=True)
|
||||
|
||||
if args.sample_only:
|
||||
return
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
unet.train()
|
||||
# train==True is required to enable gradient_checkpointing
|
||||
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
print(f"stop text encoder training at step {global_step}")
|
||||
if not args.gradient_checkpointing:
|
||||
text_encoder.train(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
if args.train_noise_seed is not None:
|
||||
torch.manual_seed(args.train_noise_seed)
|
||||
torch.cuda.manual_seed(args.train_noise_seed)
|
||||
# make same seed for each item in the batch by stacking them
|
||||
single_noise = torch.randn_like(latents[0])
|
||||
noise = torch.stack([single_noise for _ in range(b_size)])
|
||||
noise = noise.to(latents.device)
|
||||
elif args.seed_lock:
|
||||
noise = train_tools.get_noise_from_latents(latents)
|
||||
else:
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
|
||||
if args.noise_offset:
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
# elif args.perlin_noise:
|
||||
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
if train_text_encoder:
|
||||
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
else:
|
||||
params_to_clip = unet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
# checking for saving is in util
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_train_end(
|
||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||
)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_token_padding",
|
||||
action="store_true",
|
||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop_text_encoder_training",
|
||||
type=int,
|
||||
default=None,
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample_first",
|
||||
action="store_true",
|
||||
help="Sample first interval before training",
|
||||
default=False
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--name_replace",
|
||||
type=str,
|
||||
help="Replaces [name] in prompts. Used is sampling, training, and regs",
|
||||
default=None
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--train_noise_seed",
|
||||
type=int,
|
||||
help="Use custom seed for training noise",
|
||||
default=None
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample_only",
|
||||
action="store_true",
|
||||
help="Only generate samples. Used for generating training data with specific seeds to alter during training",
|
||||
default=False
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed_lock",
|
||||
action="store_true",
|
||||
help="Locks the seed to the latent images so the same latent will always have the same noise",
|
||||
default=False
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
130
testing/generate_lora_mapping.py
Normal file
130
testing/generate_lora_mapping.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')
|
||||
|
||||
# load keymap
|
||||
with open(keymap_path, 'r') as f:
|
||||
keymap = json.load(f)
|
||||
|
||||
lora_keymap = OrderedDict()
|
||||
|
||||
# convert keymap to lora key naming
|
||||
for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
|
||||
if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
|
||||
# skip it
|
||||
continue
|
||||
# sdxl has same te for locon with kohya and ours
|
||||
if ldm_key.startswith('conditioner'):
|
||||
#skip it
|
||||
continue
|
||||
# ignore vae
|
||||
if ldm_key.startswith('first_stage_model'):
|
||||
continue
|
||||
ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
|
||||
ldm_key = ldm_key.replace('.weight', '')
|
||||
ldm_key = ldm_key.replace('.', '_')
|
||||
|
||||
diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
|
||||
diffusers_key = diffusers_key.replace('.weight', '')
|
||||
diffusers_key = diffusers_key.replace('.', '_')
|
||||
|
||||
lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
|
||||
lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
|
||||
lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input", help="input file")
|
||||
parser.add_argument("input2", help="input2 file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# name = args.name
|
||||
# if args.sdxl:
|
||||
# name += '_sdxl'
|
||||
# elif args.sd2:
|
||||
# name += '_sd2'
|
||||
# else:
|
||||
# name += '_sd1'
|
||||
name = 'stable_diffusion_locon_sdxl'
|
||||
|
||||
locon_save = load_file(args.input)
|
||||
our_save = load_file(args.input2)
|
||||
|
||||
our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
|
||||
locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))
|
||||
|
||||
print(f"we have {len(our_extra_keys)} extra keys")
|
||||
print(f"locon has {len(locon_extra_keys)} extra keys")
|
||||
|
||||
save_dtype = torch.float16
|
||||
print(f"our extra keys: {our_extra_keys}")
|
||||
print(f"locon extra keys: {locon_extra_keys}")
|
||||
|
||||
|
||||
def export_state_dict(our_save):
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in our_save.items():
|
||||
# test encoders share keys for some reason
|
||||
if key.startswith('lora_te'):
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
else:
|
||||
converted_key = key
|
||||
for ldm_key, diffusers_key in lora_keymap.items():
|
||||
if converted_key == diffusers_key:
|
||||
converted_key = ldm_key
|
||||
|
||||
converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
return converted_state_dict
|
||||
|
||||
def import_state_dict(loaded_state_dict):
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in loaded_state_dict.items():
|
||||
if key.startswith('lora_te'):
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
else:
|
||||
converted_key = key
|
||||
for ldm_key, diffusers_key in lora_keymap.items():
|
||||
if converted_key == ldm_key:
|
||||
converted_key = diffusers_key
|
||||
|
||||
converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
# check it again
|
||||
converted_state_dict = export_state_dict(our_save)
|
||||
converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
|
||||
locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))
|
||||
|
||||
|
||||
print(f"we have {len(converted_extra_keys)} extra keys")
|
||||
print(f"locon has {len(locon_extra_keys)} extra keys")
|
||||
|
||||
print(f"our extra keys: {converted_extra_keys}")
|
||||
|
||||
# convert back
|
||||
cycle_state_dict = import_state_dict(converted_state_dict)
|
||||
cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
|
||||
our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))
|
||||
|
||||
print(f"we have {len(our_extra_keys)} extra keys")
|
||||
print(f"cycle has {len(cycle_extra_keys)} extra keys")
|
||||
|
||||
# save keymap
|
||||
to_save = OrderedDict()
|
||||
to_save['ldm_diffusers_keymap'] = lora_keymap
|
||||
|
||||
with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
|
||||
json.dump(to_save, f, indent=4)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
@@ -50,6 +52,8 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -60,24 +64,76 @@ find_matches = False
|
||||
|
||||
print(f'Loading diffusers model')
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
# delete things we dont need
|
||||
del diffusers_sd.tokenizer
|
||||
flush()
|
||||
ignore_ldm_begins_with = []
|
||||
|
||||
diffusers_file_path = file_path
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
|
||||
# if args.refiner:
|
||||
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||
|
||||
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
|
||||
|
||||
if not args.refiner:
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
# delete things we dont need
|
||||
del diffusers_sd.tokenizer
|
||||
flush()
|
||||
|
||||
print(f'Loading ldm model')
|
||||
diffusers_state_dict = diffusers_sd.state_dict()
|
||||
else:
|
||||
# refiner wont work directly with stable diffusion
|
||||
# so we need to load the model and then load the state dict
|
||||
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
diffusers_file_path,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
# diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
# file_path,
|
||||
# torch_dtype=torch.float16,
|
||||
# use_safetensors=True,
|
||||
# variant="fp16",
|
||||
# ).to(device)
|
||||
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_REFINER_UNET = "refiner_unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te1"
|
||||
|
||||
diffusers_state_dict = OrderedDict()
|
||||
for k, v in diffusers_pipeline.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
|
||||
# add ignore ones as we are only going to focus on unet and copy the rest
|
||||
# ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
|
||||
|
||||
print(f'Loading ldm model')
|
||||
diffusers_state_dict = diffusers_sd.state_dict()
|
||||
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||
|
||||
ldm_state_dict = load_file(file_path)
|
||||
@@ -93,18 +149,26 @@ total_keys = len(ldm_dict_keys)
|
||||
matched_ldm_keys = []
|
||||
matched_diffusers_keys = []
|
||||
|
||||
error_margin = 1e-4
|
||||
error_margin = 1e-8
|
||||
|
||||
tmp_merge_key = "TMP___MERGE"
|
||||
|
||||
te_suffix = ''
|
||||
proj_pattern_weight = None
|
||||
proj_pattern_bias = None
|
||||
text_proj_layer = None
|
||||
if args.sdxl:
|
||||
if args.sdxl or args.ssd:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.1.model.text_projection"
|
||||
if args.refiner:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.0.model.text_projection"
|
||||
if args.sd2:
|
||||
te_suffix = ''
|
||||
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
|
||||
@@ -112,10 +176,13 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2:
|
||||
if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
|
||||
else:
|
||||
d_model = 1024
|
||||
|
||||
@@ -139,7 +206,7 @@ if args.sdxl or args.sd2:
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val
|
||||
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
@@ -148,7 +215,6 @@ if args.sdxl or args.sd2:
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
||||
],
|
||||
"target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
||||
}
|
||||
|
||||
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
@@ -189,7 +255,7 @@ if args.sdxl or args.sd2:
|
||||
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val
|
||||
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
@@ -198,7 +264,6 @@ if args.sdxl or args.sd2:
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
||||
],
|
||||
# "target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
||||
}
|
||||
|
||||
# add diffusers operators
|
||||
@@ -237,11 +302,11 @@ for ldm_key in ldm_dict_keys:
|
||||
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
|
||||
|
||||
# That was easy. Same key
|
||||
if ldm_key == diffusers_key:
|
||||
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||
matched_ldm_keys.append(ldm_key)
|
||||
matched_diffusers_keys.append(diffusers_key)
|
||||
break
|
||||
# if ldm_key == diffusers_key:
|
||||
# ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||
# matched_ldm_keys.append(ldm_key)
|
||||
# matched_diffusers_keys.append(diffusers_key)
|
||||
# break
|
||||
|
||||
# if we already have this key mapped, skip it
|
||||
if diffusers_key in matched_diffusers_keys:
|
||||
@@ -266,7 +331,7 @@ for ldm_key in ldm_dict_keys:
|
||||
did_reduce_diffusers = True
|
||||
|
||||
# check to see if they match within a margin of error
|
||||
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
|
||||
mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
|
||||
if mse < error_margin:
|
||||
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||
matched_ldm_keys.append(ldm_key)
|
||||
@@ -289,6 +354,10 @@ pbar.close()
|
||||
name = args.name
|
||||
if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.refiner:
|
||||
name += '_refiner'
|
||||
elif args.sd2:
|
||||
name += '_sd2'
|
||||
else:
|
||||
@@ -359,13 +428,35 @@ for key in unmatched_ldm_keys:
|
||||
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
|
||||
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
|
||||
|
||||
# do cleanup of some left overs and bugs
|
||||
to_remove = []
|
||||
for ldm_key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
# get rid of tmp merge keys used to slicing
|
||||
if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key:
|
||||
to_remove.append(ldm_key)
|
||||
|
||||
for key in to_remove:
|
||||
del ldm_diffusers_keymap[key]
|
||||
|
||||
to_remove = []
|
||||
# remove identical shape mappings. Not sure why they exist but they do
|
||||
for ldm_key, shape_list in ldm_diffusers_shape_map.items():
|
||||
# remove identical shape mappings. Not sure why they exist but they do
|
||||
# convert to json string to make it easier to compare
|
||||
ldm_shape = json.dumps(shape_list[0])
|
||||
diffusers_shape = json.dumps(shape_list[1])
|
||||
if ldm_shape == diffusers_shape:
|
||||
to_remove.append(ldm_key)
|
||||
|
||||
for key in to_remove:
|
||||
del ldm_diffusers_shape_map[key]
|
||||
|
||||
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
|
||||
save_obj = OrderedDict()
|
||||
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
|
||||
save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
|
||||
save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
|
||||
save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map
|
||||
|
||||
with open(dest_path, 'w') as f:
|
||||
f.write(json.dumps(save_obj, indent=4))
|
||||
|
||||
|
||||
@@ -1,37 +1,107 @@
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
# make sure we can import from the toolkit
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
import os
|
||||
import cv2
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
|
||||
from toolkit.image_utils import show_img
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from library.model_util import load_vae
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
|
||||
trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
resolution = 512
|
||||
resolution = 1024
|
||||
bucket_tolerance = 64
|
||||
batch_size = 4
|
||||
batch_size = 1
|
||||
|
||||
|
||||
##
|
||||
|
||||
dataset_config = DatasetConfig(
|
||||
folder_path=dataset_folder,
|
||||
dataset_path=dataset_folder,
|
||||
resolution=resolution,
|
||||
caption_type='txt',
|
||||
caption_ext='json',
|
||||
default_caption='default',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
poi='person',
|
||||
augmentations=[
|
||||
{
|
||||
'method': 'RandomBrightnessContrast',
|
||||
'brightness_limit': (-0.3, 0.3),
|
||||
'contrast_limit': (-0.3, 0.3),
|
||||
'brightness_by_max': False,
|
||||
'p': 1.0
|
||||
},
|
||||
{
|
||||
'method': 'HueSaturationValue',
|
||||
'hue_shift_limit': (-0, 0),
|
||||
'sat_shift_limit': (-40, 40),
|
||||
'val_shift_limit': (-40, 40),
|
||||
'p': 1.0
|
||||
},
|
||||
# {
|
||||
# 'method': 'RGBShift',
|
||||
# 'r_shift_limit': (-20, 20),
|
||||
# 'g_shift_limit': (-20, 20),
|
||||
# 'b_shift_limit': (-20, 20),
|
||||
# 'p': 1.0
|
||||
# },
|
||||
]
|
||||
|
||||
|
||||
)
|
||||
|
||||
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
|
||||
|
||||
# run through an epoch ang check sizes
|
||||
for batch in dataloader:
|
||||
print(list(batch[0].shape))
|
||||
dataloader_iterator = iter(dataloader)
|
||||
for epoch in range(args.epochs):
|
||||
for batch in dataloader:
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
print('done')
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
@@ -20,6 +22,8 @@ from toolkit.stable_diffusion_model import StableDiffusion
|
||||
# you probably wont need this. Unless they change them.... again... again
|
||||
# on second thought, you probably will
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
@@ -109,7 +113,26 @@ keys_in_both.sort()
|
||||
|
||||
if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0:
|
||||
print("All keys match!")
|
||||
exit(0)
|
||||
print("Checking values...")
|
||||
mismatch_keys = []
|
||||
loss = torch.nn.MSELoss()
|
||||
tolerance = 1e-6
|
||||
for key in tqdm(keys_in_both):
|
||||
if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance:
|
||||
print(f"Values for key {key} don't match!")
|
||||
print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}")
|
||||
mismatch_keys.append(key)
|
||||
|
||||
if len(mismatch_keys) == 0:
|
||||
print("All values match!")
|
||||
else:
|
||||
print("Some valued font match!")
|
||||
print(mismatch_keys)
|
||||
mismatched_path = os.path.join(project_root, 'config', 'mismatch.json')
|
||||
with open(mismatched_path, 'w') as f:
|
||||
f.write(json.dumps(mismatch_keys, indent=4))
|
||||
exit(0)
|
||||
|
||||
else:
|
||||
print("Keys don't match!, generating info...")
|
||||
|
||||
@@ -132,17 +155,17 @@ for key in keys_not_in_state_dict_2:
|
||||
|
||||
# print(json_data)
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
json_save_path = os.path.join(project_root, 'config', 'keys.json')
|
||||
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
|
||||
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
|
||||
state_dict_1_filename = os.path.basename(args.file_1[0])
|
||||
state_dict_2_filename = os.path.basename(args.file_2[0])
|
||||
# state_dict_2_filename = os.path.basename(args.file_2[0])
|
||||
# save key names for each in own file
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
|
||||
f.write(json.dumps(state_dict_1_keys, indent=4))
|
||||
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f:
|
||||
f.write(json.dumps(state_dict_2_keys, indent=4))
|
||||
|
||||
with open(json_save_path, 'w') as f:
|
||||
|
||||
@@ -1,4 +1,50 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def value_map(inputs, min_in, max_in, min_out, max_out):
|
||||
return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
|
||||
|
||||
|
||||
def flush(garbage_collect=True):
|
||||
torch.cuda.empty_cache()
|
||||
if garbage_collect:
|
||||
gc.collect()
|
||||
|
||||
|
||||
def get_mean_std(tensor):
|
||||
if len(tensor.shape) == 3:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
elif len(tensor.shape) != 4:
|
||||
raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
|
||||
mean, variance = torch.mean(
|
||||
tensor, dim=[2, 3], keepdim=True
|
||||
), torch.var(
|
||||
tensor, dim=[2, 3],
|
||||
keepdim=True
|
||||
)
|
||||
std = torch.sqrt(variance + 1e-5)
|
||||
return mean, std
|
||||
|
||||
|
||||
def adain(content_features, style_features):
|
||||
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
||||
|
||||
# Step 1: Calculate mean and variance of content features
|
||||
content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features,
|
||||
dim=[2, 3],
|
||||
keepdim=True)
|
||||
# Step 2: Calculate mean and variance of style features
|
||||
style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3],
|
||||
keepdim=True)
|
||||
|
||||
# Step 3: Normalize content features
|
||||
content_std = torch.sqrt(content_var + 1e-5)
|
||||
normalized_content = (content_features - content_mean) / content_std
|
||||
|
||||
# Step 4: Scale and shift normalized content with style's statistics
|
||||
style_std = torch.sqrt(style_var + 1e-5)
|
||||
stylized_content = normalized_content * style_std + style_mean
|
||||
|
||||
return stylized_content
|
||||
|
||||
127
toolkit/buckets.py
Normal file
127
toolkit/buckets.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Type, List, Union, TypedDict
|
||||
|
||||
|
||||
class BucketResolution(TypedDict):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
# resolutions SDXL was trained on with a 1024x1024 base resolution
|
||||
resolutions_1024: List[BucketResolution] = [
|
||||
# SDXL Base resolution
|
||||
{"width": 1024, "height": 1024},
|
||||
# SDXL Resolutions, widescreen
|
||||
{"width": 2048, "height": 512},
|
||||
{"width": 1984, "height": 512},
|
||||
{"width": 1920, "height": 512},
|
||||
{"width": 1856, "height": 512},
|
||||
{"width": 1792, "height": 576},
|
||||
{"width": 1728, "height": 576},
|
||||
{"width": 1664, "height": 576},
|
||||
{"width": 1600, "height": 640},
|
||||
{"width": 1536, "height": 640},
|
||||
{"width": 1472, "height": 704},
|
||||
{"width": 1408, "height": 704},
|
||||
{"width": 1344, "height": 704},
|
||||
{"width": 1344, "height": 768},
|
||||
{"width": 1280, "height": 768},
|
||||
{"width": 1216, "height": 832},
|
||||
{"width": 1152, "height": 832},
|
||||
{"width": 1152, "height": 896},
|
||||
{"width": 1088, "height": 896},
|
||||
{"width": 1088, "height": 960},
|
||||
{"width": 1024, "height": 960},
|
||||
# SDXL Resolutions, portrait
|
||||
{"width": 960, "height": 1024},
|
||||
{"width": 960, "height": 1088},
|
||||
{"width": 896, "height": 1088},
|
||||
{"width": 896, "height": 1152}, # 2:3
|
||||
{"width": 832, "height": 1152},
|
||||
{"width": 832, "height": 1216},
|
||||
{"width": 768, "height": 1280},
|
||||
{"width": 768, "height": 1344},
|
||||
{"width": 704, "height": 1408},
|
||||
{"width": 704, "height": 1472},
|
||||
{"width": 640, "height": 1536},
|
||||
{"width": 640, "height": 1600},
|
||||
{"width": 576, "height": 1664},
|
||||
{"width": 576, "height": 1728},
|
||||
{"width": 576, "height": 1792},
|
||||
{"width": 512, "height": 1856},
|
||||
{"width": 512, "height": 1920},
|
||||
{"width": 512, "height": 1984},
|
||||
{"width": 512, "height": 2048},
|
||||
]
|
||||
|
||||
|
||||
def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
|
||||
# determine scaler form 1024 to resolution
|
||||
scaler = resolution / 1024
|
||||
|
||||
bucket_size_list = []
|
||||
for bucket in resolutions_1024:
|
||||
# must be divisible by 8
|
||||
width = int(bucket["width"] * scaler)
|
||||
height = int(bucket["height"] * scaler)
|
||||
if width % divisibility != 0:
|
||||
width = width - (width % divisibility)
|
||||
if height % divisibility != 0:
|
||||
height = height - (height % divisibility)
|
||||
bucket_size_list.append({"width": width, "height": height})
|
||||
|
||||
return bucket_size_list
|
||||
|
||||
|
||||
def get_resolution(width, height):
|
||||
num_pixels = width * height
|
||||
# determine same number of pixels for square image
|
||||
square_resolution = int(num_pixels ** 0.5)
|
||||
return square_resolution
|
||||
|
||||
|
||||
def get_bucket_for_image_size(
|
||||
width: int,
|
||||
height: int,
|
||||
bucket_size_list: List[BucketResolution] = None,
|
||||
resolution: Union[int, None] = None,
|
||||
divisibility: int = 8
|
||||
) -> BucketResolution:
|
||||
|
||||
if bucket_size_list is None and resolution is None:
|
||||
# get resolution from width and height
|
||||
resolution = get_resolution(width, height)
|
||||
if bucket_size_list is None:
|
||||
# if real resolution is smaller, use that instead
|
||||
real_resolution = get_resolution(width, height)
|
||||
resolution = min(resolution, real_resolution)
|
||||
bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
|
||||
|
||||
# Check for exact match first
|
||||
for bucket in bucket_size_list:
|
||||
if bucket["width"] == width and bucket["height"] == height:
|
||||
return bucket
|
||||
|
||||
# If exact match not found, find the closest bucket
|
||||
closest_bucket = None
|
||||
min_removed_pixels = float("inf")
|
||||
|
||||
for bucket in bucket_size_list:
|
||||
scale_w = bucket["width"] / width
|
||||
scale_h = bucket["height"] / height
|
||||
|
||||
# To minimize pixels, we use the larger scale factor to minimize the amount that has to be cropped.
|
||||
scale = max(scale_w, scale_h)
|
||||
|
||||
new_width = int(width * scale)
|
||||
new_height = int(height * scale)
|
||||
|
||||
removed_pixels = (new_width - bucket["width"]) * new_height + (new_height - bucket["height"]) * new_width
|
||||
|
||||
if removed_pixels < min_removed_pixels:
|
||||
min_removed_pixels = removed_pixels
|
||||
closest_bucket = bucket
|
||||
|
||||
if closest_bucket is None:
|
||||
raise ValueError("No suitable bucket found")
|
||||
|
||||
return closest_bucket
|
||||
@@ -17,6 +17,24 @@ def get_cwd_abs_path(path):
|
||||
return path
|
||||
|
||||
|
||||
def replace_env_vars_in_string(s: str) -> str:
|
||||
"""
|
||||
Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
|
||||
If the environment variable is not set, raise an error.
|
||||
"""
|
||||
|
||||
def replacer(match):
|
||||
var_name = match.group(1)
|
||||
value = os.environ.get(var_name)
|
||||
|
||||
if value is None:
|
||||
raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.")
|
||||
|
||||
return value
|
||||
|
||||
return re.sub(r'\$\{([^}]+)\}', replacer, s)
|
||||
|
||||
|
||||
def preprocess_config(config: OrderedDict, name: str = None):
|
||||
if "job" not in config:
|
||||
raise ValueError("config file must have a job key")
|
||||
@@ -81,13 +99,14 @@ def get_config(
|
||||
raise ValueError(f"Could not find config file {config_file_path}")
|
||||
|
||||
# if we found it, check if it is a json or yaml file
|
||||
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.load(f, Loader=fixed_loader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
with open(real_config_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
content_with_env_replaced = replace_env_vars_in_string(content)
|
||||
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
|
||||
config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
|
||||
return preprocess_config(config, name)
|
||||
|
||||
@@ -1,14 +1,25 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Literal
|
||||
from typing import List, Optional, Literal, Union
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
SaveFormat = Literal['safetensors', 'diffusers']
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
self.dtype: str = kwargs.get('save_dtype', 'float16')
|
||||
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
|
||||
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
||||
if self.save_format not in ['safetensors', 'diffusers']:
|
||||
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
||||
|
||||
|
||||
class LogingConfig:
|
||||
@@ -20,6 +31,7 @@ class LogingConfig:
|
||||
|
||||
class SampleConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.sampler: str = kwargs.get('sampler', 'ddpm')
|
||||
self.sample_every: int = kwargs.get('sample_every', 100)
|
||||
self.width: int = kwargs.get('width', 512)
|
||||
self.height: int = kwargs.get('height', 512)
|
||||
@@ -31,11 +43,59 @@ class SampleConfig:
|
||||
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
|
||||
|
||||
|
||||
class LormModuleSettingsConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.contains: str = kwargs.get('contains', '4nt$3')
|
||||
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
|
||||
# min num parameters to attach to
|
||||
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
|
||||
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
|
||||
|
||||
|
||||
class LoRMConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
|
||||
self.do_conv: bool = kwargs.get('do_conv', False)
|
||||
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
|
||||
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
|
||||
module_settings = kwargs.get('module_settings', [])
|
||||
default_module_settings = {
|
||||
'extract_mode': self.extract_mode,
|
||||
'extract_mode_param': self.extract_mode_param,
|
||||
'parameter_threshold': self.parameter_threshold,
|
||||
}
|
||||
module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
|
||||
self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
|
||||
module_setting in module_settings]
|
||||
|
||||
def get_config_for_module(self, block_name):
|
||||
for setting in self.module_settings:
|
||||
contain_pieces = setting.contains.split('|')
|
||||
if all(contain_piece in block_name for contain_piece in contain_pieces):
|
||||
return setting
|
||||
# try replacing the . with _
|
||||
contain_pieces = setting.contains.replace('.', '_').split('|')
|
||||
if all(contain_piece in block_name for contain_piece in contain_pieces):
|
||||
return setting
|
||||
# do default
|
||||
return LormModuleSettingsConfig(**{
|
||||
'extract_mode': self.extract_mode,
|
||||
'extract_mode_param': self.extract_mode_param,
|
||||
'parameter_threshold': self.parameter_threshold,
|
||||
})
|
||||
|
||||
|
||||
NetworkType = Literal['lora', 'locon', 'lorm']
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: str = kwargs.get('type', 'lora')
|
||||
self.type: NetworkType = kwargs.get('type', 'lora')
|
||||
rank = kwargs.get('rank', None)
|
||||
linear = kwargs.get('linear', None)
|
||||
if rank is not None:
|
||||
@@ -48,7 +108,37 @@ class NetworkConfig:
|
||||
self.alpha: float = kwargs.get('alpha', 1.0)
|
||||
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
||||
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
|
||||
self.normalize = kwargs.get('normalize', False)
|
||||
self.dropout: Union[float, None] = kwargs.get('dropout', None)
|
||||
|
||||
self.lorm_config: Union[LoRMConfig, None] = None
|
||||
lorm = kwargs.get('lorm', None)
|
||||
if lorm is not None:
|
||||
self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
|
||||
|
||||
if self.type == 'lorm':
|
||||
# set linear to arbitrary values so it makes them
|
||||
self.linear = 4
|
||||
self.rank = 4
|
||||
if self.lorm_config.do_conv:
|
||||
self.conv = 4
|
||||
|
||||
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+']
|
||||
|
||||
|
||||
class AdapterConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
|
||||
self.in_channels: int = kwargs.get('in_channels', 3)
|
||||
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
|
||||
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
|
||||
self.downscale_factor: int = kwargs.get('downscale_factor', 8)
|
||||
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
|
||||
self.image_dir: str = kwargs.get('image_dir', None)
|
||||
self.test_img_path: str = kwargs.get('test_img_path', None)
|
||||
self.train: str = kwargs.get('train', False)
|
||||
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
||||
self.name_or_path = kwargs.get('name_or_path', None)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -59,26 +149,94 @@ class EmbeddingConfig:
|
||||
self.save_format = kwargs.get('save_format', 'safetensors')
|
||||
|
||||
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
|
||||
self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
|
||||
self.steps: int = kwargs.get('steps', 1000)
|
||||
self.lr = kwargs.get('lr', 1e-6)
|
||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
|
||||
self.refiner_lr = kwargs.get('refiner_lr', self.lr)
|
||||
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
|
||||
self.adapter_lr = kwargs.get('adapter_lr', self.lr)
|
||||
self.optimizer = kwargs.get('optimizer', 'adamw')
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
self.sdp = kwargs.get('sdp', False)
|
||||
self.train_unet = kwargs.get('train_unet', True)
|
||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||
self.train_refiner = kwargs.get('train_refiner', True)
|
||||
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||
self.snr_gamma = kwargs.get('snr_gamma', None)
|
||||
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
|
||||
# this should balance the learning rate across all timesteps over time
|
||||
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
|
||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
|
||||
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
|
||||
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
|
||||
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
|
||||
self.start_step = kwargs.get('start_step', None)
|
||||
self.free_u = kwargs.get('free_u', False)
|
||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
|
||||
self.negative_prompt = kwargs.get('negative_prompt', None)
|
||||
# multiplier applied to loos on regularization images
|
||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||
|
||||
# dropout that happens before encoding. It functions independently per text encoder
|
||||
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
||||
|
||||
# match the norm of the noise before computing loss. This will help the model maintain its
|
||||
# current understandin of the brightness of images.
|
||||
|
||||
self.match_noise_norm = kwargs.get('match_noise_norm', False)
|
||||
|
||||
# set to -1 to accumulate gradients for entire epoch
|
||||
# warning, only do this with a small dataset or you will run out of memory
|
||||
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
|
||||
|
||||
# short long captions will double your batch size. This only works when a dataset is
|
||||
# prepared with a json caption file that has both short and long captions in it. It will
|
||||
# Double up every image and run it through with both short and long captions. The idea
|
||||
# is that the network will learn how to generate good images with both short and long captions
|
||||
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
|
||||
# if above is NOT true, this will make it so the long caption foes to te2 and the short caption goes to te1 for sdxl only
|
||||
self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
|
||||
|
||||
# basically gradient accumulation but we run just 1 item through the network
|
||||
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
|
||||
# for training tricks that increase batch size but need a single gradient step
|
||||
self.single_item_batching = kwargs.get('single_item_batching', False)
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target',
|
||||
'noise') # noise, source, unaugmented, differential_noise
|
||||
|
||||
# When a mask is passed in a dataset, and this is true,
|
||||
# we will predict noise without a the LoRa network and use the prediction as a target for
|
||||
# unmasked reign. It is unmasked regularization basically
|
||||
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
|
||||
self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
self.match_adapter_chance = 1.0
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@@ -86,17 +244,27 @@ class ModelConfig:
|
||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||
self.is_v2: bool = kwargs.get('is_v2', False)
|
||||
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||
self.vae_path = kwargs.get('vae_path', None)
|
||||
self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
|
||||
self._original_refiner_name_or_path = self.refiner_name_or_path
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
|
||||
|
||||
# only for SDXL models for now
|
||||
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
|
||||
self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
|
||||
|
||||
self.experimental_xl: bool = kwargs.get('experimental_xl', False)
|
||||
|
||||
if self.name_or_path is None:
|
||||
raise ValueError('name_or_path must be specified')
|
||||
|
||||
if self.is_ssd:
|
||||
# sed sdxl as true since it is mostly the same architecture
|
||||
self.is_xl = True
|
||||
|
||||
|
||||
class ReferenceDatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -126,6 +294,14 @@ class SliderTargetConfig:
|
||||
self.shuffle: bool = kwargs.get('shuffle', False)
|
||||
|
||||
|
||||
class GuidanceConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.guidance_scale: float = kwargs.get('guidance_scale', 1.0)
|
||||
self.positive_prompt: str = kwargs.get('positive_prompt', '')
|
||||
self.negative_prompt: str = kwargs.get('negative_prompt', '')
|
||||
|
||||
|
||||
class SliderConfigAnchors:
|
||||
def __init__(self, **kwargs):
|
||||
self.prompt = kwargs.get('prompt', '')
|
||||
@@ -143,28 +319,41 @@ class SliderConfig:
|
||||
self.prompt_file: str = kwargs.get('prompt_file', None)
|
||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
|
||||
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
|
||||
self.low_ram = kwargs.get('low_ram', False)
|
||||
|
||||
# expand targets if shuffling
|
||||
from toolkit.prompt_utils import get_slider_target_permutations
|
||||
self.targets: List[SliderTargetConfig] = []
|
||||
targets = [SliderTargetConfig(**target) for target in targets]
|
||||
# do permutations if shuffle is true
|
||||
print(f"Building slider targets")
|
||||
for target in targets:
|
||||
if target.shuffle:
|
||||
target_permutations = get_slider_target_permutations(target)
|
||||
target_permutations = get_slider_target_permutations(target, max_permutations=8)
|
||||
self.targets = self.targets + target_permutations
|
||||
else:
|
||||
self.targets.append(target)
|
||||
print(f"Built {len(self.targets)} slider targets (with permutations)")
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
caption_type: Literal["txt", "caption"] = 'txt'
|
||||
"""
|
||||
Dataset config for sd-datasets
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.type = kwargs.get('type', 'image') # sd, slider, reference
|
||||
# will be legacy
|
||||
self.folder_path: str = kwargs.get('folder_path', None)
|
||||
# can be json or folder path
|
||||
self.dataset_path: str = kwargs.get('dataset_path', None)
|
||||
|
||||
self.default_caption: str = kwargs.get('default_caption', None)
|
||||
self.caption_type: str = kwargs.get('caption_type', None)
|
||||
self.random_triggers: List[str] = kwargs.get('random_triggers', [])
|
||||
self.caption_ext: str = kwargs.get('caption_ext', None)
|
||||
self.random_scale: bool = kwargs.get('random_scale', False)
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
@@ -172,6 +361,66 @@ class DatasetConfig:
|
||||
self.buckets: bool = kwargs.get('buckets', False)
|
||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
self.mask_path: str = kwargs.get('mask_path',
|
||||
None) # focus mask (black and white. White has higher loss than black)
|
||||
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
|
||||
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi',
|
||||
None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
|
||||
# cache latents will store them in memory
|
||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False)
|
||||
|
||||
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
|
||||
|
||||
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
|
||||
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
|
||||
self.cache_latents = False
|
||||
self.cache_latents_to_disk = False
|
||||
|
||||
# legacy compatability
|
||||
legacy_caption_type = kwargs.get('caption_type', None)
|
||||
if legacy_caption_type:
|
||||
self.caption_ext = legacy_caption_type
|
||||
self.caption_type = self.caption_ext
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
"""
|
||||
This just splits up the datasets by resolutions so you dont have to do it manually
|
||||
:param raw_config:
|
||||
:return:
|
||||
"""
|
||||
# split up datasets by resolutions
|
||||
new_config = []
|
||||
for dataset in raw_config:
|
||||
resolution = dataset.get('resolution', 512)
|
||||
if isinstance(resolution, list):
|
||||
resolution_list = resolution
|
||||
else:
|
||||
resolution_list = [resolution]
|
||||
for res in resolution_list:
|
||||
dataset_copy = dataset.copy()
|
||||
dataset_copy['resolution'] = res
|
||||
new_config.append(dataset_copy)
|
||||
return new_config
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
@@ -191,9 +440,14 @@ class GenerateImageConfig:
|
||||
# the tag [time] will be replaced with milliseconds since epoch
|
||||
output_path: str = None, # full image path
|
||||
output_folder: str = None, # folder to save image in if output_path is not specified
|
||||
output_ext: str = 'png', # extension to save image as if output_path is not specified
|
||||
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
|
||||
output_tail: str = '', # tail to add to output filename
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
adapter_image_path: str = None, # path to adapter image
|
||||
adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
|
||||
latents: Union[torch.Tensor | None] = None, # input latent to start with,
|
||||
extra_kwargs: dict = None, # extra data to save with prompt file
|
||||
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -204,6 +458,7 @@ class GenerateImageConfig:
|
||||
self.prompt_2: str = prompt_2
|
||||
self.negative_prompt: str = negative_prompt
|
||||
self.negative_prompt_2: str = negative_prompt_2
|
||||
self.latents: Union[torch.Tensor | None] = latents
|
||||
|
||||
self.output_path: str = output_path
|
||||
self.seed: int = seed
|
||||
@@ -216,6 +471,10 @@ class GenerateImageConfig:
|
||||
self.add_prompt_file: bool = add_prompt_file
|
||||
self.output_tail: str = output_tail
|
||||
self.gen_time: int = int(time.time() * 1000)
|
||||
self.adapter_image_path: str = adapter_image_path
|
||||
self.adapter_conditioning_scale: float = adapter_conditioning_scale
|
||||
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
|
||||
self.refiner_start_at = refiner_start_at
|
||||
|
||||
# prompt string will override any settings above
|
||||
self._process_prompt_string()
|
||||
@@ -370,3 +629,15 @@ class GenerateImageConfig:
|
||||
self.network_multiplier = float(content)
|
||||
elif flag == 'gr':
|
||||
self.guidance_rescale = float(content)
|
||||
elif flag == 'a':
|
||||
self.adapter_conditioning_scale = float(content)
|
||||
elif flag == 'ref':
|
||||
self.refiner_start_at = float(content)
|
||||
|
||||
def post_process_embeddings(
|
||||
self,
|
||||
conditional_prompt_embeds: PromptEmbeds,
|
||||
unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
|
||||
):
|
||||
# this is called after prompt embeds are encoded. We can override them in the future here
|
||||
pass
|
||||
|
||||
93
toolkit/cuda_malloc.py
Normal file
93
toolkit/cuda_malloc.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# ref comfy ui
|
||||
import os
|
||||
import importlib.util
|
||||
|
||||
|
||||
# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||
def get_gpu_names():
|
||||
if os.name == 'nt':
|
||||
import ctypes
|
||||
|
||||
# Define necessary C structures and types
|
||||
class DISPLAY_DEVICEA(ctypes.Structure):
|
||||
_fields_ = [
|
||||
('cb', ctypes.c_ulong),
|
||||
('DeviceName', ctypes.c_char * 32),
|
||||
('DeviceString', ctypes.c_char * 128),
|
||||
('StateFlags', ctypes.c_ulong),
|
||||
('DeviceID', ctypes.c_char * 128),
|
||||
('DeviceKey', ctypes.c_char * 128)
|
||||
]
|
||||
|
||||
# Load user32.dll
|
||||
user32 = ctypes.windll.user32
|
||||
|
||||
# Call EnumDisplayDevicesA
|
||||
def enum_display_devices():
|
||||
device_info = DISPLAY_DEVICEA()
|
||||
device_info.cb = ctypes.sizeof(device_info)
|
||||
device_index = 0
|
||||
gpu_names = set()
|
||||
|
||||
while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
|
||||
device_index += 1
|
||||
gpu_names.add(device_info.DeviceString.decode('utf-8'))
|
||||
return gpu_names
|
||||
|
||||
return enum_display_devices()
|
||||
else:
|
||||
return set()
|
||||
|
||||
|
||||
blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950",
|
||||
"GeForce 945M",
|
||||
"GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745",
|
||||
"Quadro K620",
|
||||
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
|
||||
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000",
|
||||
"Quadro M5500", "Quadro M6000",
|
||||
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
|
||||
"GeForce GTX 1650", "GeForce GTX 1630"
|
||||
}
|
||||
|
||||
|
||||
def cuda_malloc_supported():
|
||||
try:
|
||||
names = get_gpu_names()
|
||||
except:
|
||||
names = set()
|
||||
for x in names:
|
||||
if "NVIDIA" in x:
|
||||
for b in blacklist:
|
||||
if b in x:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
cuda_malloc = False
|
||||
|
||||
if not cuda_malloc:
|
||||
try:
|
||||
version = ""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
ver_file = os.path.join(folder, "version.py")
|
||||
if os.path.isfile(ver_file):
|
||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
|
||||
cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
if cuda_malloc:
|
||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||
if env_var is None:
|
||||
env_var = "backend:cudaMallocAsync"
|
||||
else:
|
||||
env_var += ",backend:cudaMallocAsync"
|
||||
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||
print("CUDA Malloc Async Enabled")
|
||||
@@ -1,10 +1,13 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
import traceback
|
||||
from functools import lru_cache
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
@@ -12,9 +15,13 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
||||
from tqdm import tqdm
|
||||
import albumentations as A
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
@@ -27,7 +34,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
self.include_prompt = self.get_config('include_prompt', False)
|
||||
self.default_prompt = self.get_config('default_prompt', '')
|
||||
if self.include_prompt:
|
||||
self.caption_type = self.get_config('caption_type', 'txt')
|
||||
self.caption_type = self.get_config('caption_ext', 'txt')
|
||||
else:
|
||||
self.caption_type = None
|
||||
# we always random crop if random scale is enabled
|
||||
@@ -48,6 +55,8 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
else:
|
||||
bad_count += 1
|
||||
|
||||
self.file_list = new_file_list
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.path}"
|
||||
@@ -85,7 +94,10 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
img = img.resize((scale_size, scale_size), Image.BICUBIC)
|
||||
scaler = scale_size / min_img_size
|
||||
scale_width = int((img.width + 5) * scaler)
|
||||
scale_height = int((img.height + 5) * scaler)
|
||||
img = img.resize((scale_width, scale_height), Image.BICUBIC)
|
||||
img = transforms.RandomCrop(self.resolution)(img)
|
||||
else:
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
@@ -100,21 +112,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
return img
|
||||
|
||||
|
||||
class Augments:
|
||||
def __init__(self, **kwargs):
|
||||
self.method_name = kwargs.get('method', None)
|
||||
self.params = kwargs.get('params', {})
|
||||
|
||||
# convert kwargs enums for cv2
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, str):
|
||||
# split the string
|
||||
split_string = value.split('.')
|
||||
if len(split_string) == 2 and split_string[0] == 'cv2':
|
||||
if hasattr(cv2, split_string[1]):
|
||||
self.params[key] = getattr(cv2, split_string[1].upper())
|
||||
else:
|
||||
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
||||
|
||||
|
||||
class AugmentedImageDataset(ImageDataset):
|
||||
@@ -265,6 +263,38 @@ class PairedImageDataset(Dataset):
|
||||
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path = img_path_or_tuple[1]
|
||||
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
|
||||
# always use # 2 (pos)
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
width=img2.width,
|
||||
height=img2.height,
|
||||
resolution=self.size,
|
||||
# divisibility=self.
|
||||
)
|
||||
|
||||
# images will be same base dimension, but may be trimmed. We need to shrink and then central crop
|
||||
if bucket_resolution['width'] > bucket_resolution['height']:
|
||||
img1_scale_to_height = bucket_resolution["height"]
|
||||
img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
|
||||
img2_scale_to_height = bucket_resolution["height"]
|
||||
img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
|
||||
else:
|
||||
img1_scale_to_width = bucket_resolution["width"]
|
||||
img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
|
||||
img2_scale_to_width = bucket_resolution["width"]
|
||||
img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
|
||||
|
||||
img1_crop_height = bucket_resolution["height"]
|
||||
img1_crop_width = bucket_resolution["width"]
|
||||
img2_crop_height = bucket_resolution["height"]
|
||||
img2_crop_width = bucket_resolution["width"]
|
||||
|
||||
# scale then center crop images
|
||||
img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
|
||||
img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
|
||||
img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
|
||||
img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
|
||||
|
||||
# combine them side by side
|
||||
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
|
||||
img.paste(img1, (0, 0))
|
||||
@@ -272,52 +302,45 @@ class PairedImageDataset(Dataset):
|
||||
else:
|
||||
img_path = img_path_or_tuple
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
height = self.size
|
||||
# determine width to keep aspect ratio
|
||||
width = int(img.size[0] * height / img.size[1])
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((width, height), Image.BICUBIC)
|
||||
|
||||
prompt = self.get_prompt_item(index)
|
||||
|
||||
height = self.size
|
||||
# determine width to keep aspect ratio
|
||||
width = int(img.size[0] * height / img.size[1])
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((width, height), Image.BICUBIC)
|
||||
img = self.transform(img)
|
||||
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
printed_messages = []
|
||||
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
|
||||
|
||||
def print_once(msg):
|
||||
global printed_messages
|
||||
if msg not in printed_messages:
|
||||
print(msg)
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItem:
|
||||
def __init__(self, **kwargs):
|
||||
self.path = kwargs.get('path', None)
|
||||
self.width = kwargs.get('width', None)
|
||||
self.height = kwargs.get('height', None)
|
||||
# we scale first, then crop
|
||||
self.scale_to_width = kwargs.get('scale_to_width', self.width)
|
||||
self.scale_to_height = kwargs.get('scale_to_height', self.height)
|
||||
# crop values are from scaled size
|
||||
self.crop_x = kwargs.get('crop_x', 0)
|
||||
self.crop_y = kwargs.get('crop_y', 0)
|
||||
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
|
||||
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
|
||||
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_config: 'DatasetConfig',
|
||||
batch_size=1,
|
||||
sd: 'StableDiffusion' = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_config = dataset_config
|
||||
self.folder_path = dataset_config.folder_path
|
||||
self.caption_type = dataset_config.caption_type
|
||||
folder_path = dataset_config.folder_path
|
||||
self.dataset_path = dataset_config.dataset_path
|
||||
if self.dataset_path is None:
|
||||
self.dataset_path = folder_path
|
||||
|
||||
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
|
||||
self.is_caching_latents_to_memory = dataset_config.cache_latents
|
||||
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
|
||||
self.epoch_num = 0
|
||||
|
||||
self.sd = sd
|
||||
|
||||
if self.sd is None and self.is_caching_latents:
|
||||
raise ValueError(f"sd is required for caching latents")
|
||||
|
||||
self.caption_type = dataset_config.caption_ext
|
||||
self.default_caption = dataset_config.default_caption
|
||||
self.random_scale = dataset_config.random_scale
|
||||
self.scale = dataset_config.scale
|
||||
@@ -325,176 +348,211 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
# we always random crop if random scale is enabled
|
||||
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
||||
self.resolution = dataset_config.resolution
|
||||
self.file_list: List['FileItem'] = []
|
||||
self.caption_dict = None
|
||||
self.file_list: List['FileItemDTO'] = []
|
||||
|
||||
# get the file list
|
||||
file_list = [
|
||||
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
]
|
||||
# check if dataset_path is a folder or json
|
||||
if os.path.isdir(self.dataset_path):
|
||||
file_list = [
|
||||
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
]
|
||||
else:
|
||||
# assume json
|
||||
with open(self.dataset_path, 'r') as f:
|
||||
self.caption_dict = json.load(f)
|
||||
# keys are file paths
|
||||
file_list = list(self.caption_dict.keys())
|
||||
|
||||
if self.dataset_config.num_repeats > 1:
|
||||
# repeat the list
|
||||
file_list = file_list * self.dataset_config.num_repeats
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
bad_count = 0
|
||||
for file in tqdm(file_list):
|
||||
try:
|
||||
w, h = image_utils.get_image_size(file)
|
||||
except image_utils.UnknownImageFormat:
|
||||
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
||||
f'This process is faster for png, jpeg')
|
||||
img = Image.open(file)
|
||||
h, w = img.size
|
||||
if int(min(h, w) * self.scale) >= self.resolution:
|
||||
self.file_list.append(
|
||||
FileItem(
|
||||
path=file,
|
||||
width=w,
|
||||
height=h,
|
||||
scale_to_width=int(w * self.scale),
|
||||
scale_to_height=int(h * self.scale),
|
||||
)
|
||||
file_item = FileItemDTO(
|
||||
path=file,
|
||||
dataset_config=dataset_config
|
||||
)
|
||||
else:
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Error processing image: {file}")
|
||||
print(e)
|
||||
bad_count += 1
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
|
||||
# print(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
# handle x axis flips
|
||||
if self.dataset_config.flip_x:
|
||||
print(" - adding x axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the x axis
|
||||
new_file_item = copy.deepcopy(file_item)
|
||||
new_file_item.flip_x = True
|
||||
self.file_list.append(new_file_item)
|
||||
|
||||
# handle y axis flips
|
||||
if self.dataset_config.flip_y:
|
||||
print(" - adding y axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the y axis
|
||||
new_file_item = copy.deepcopy(file_item)
|
||||
new_file_item.flip_y = True
|
||||
self.file_list.append(new_file_item)
|
||||
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
def setup_epoch(self):
|
||||
if self.epoch_num == 0:
|
||||
# initial setup
|
||||
# do not call for now
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
if self.is_caching_latents:
|
||||
self.cache_latents_all_latents()
|
||||
else:
|
||||
if self.dataset_config.poi is not None:
|
||||
# handle cropping to a specific point of interest
|
||||
# setup buckets every epoch
|
||||
self.setup_buckets(quiet=True)
|
||||
self.epoch_num += 1
|
||||
|
||||
def __len__(self):
|
||||
if self.dataset_config.buckets:
|
||||
return len(self.batch_indices)
|
||||
return len(self.file_list)
|
||||
|
||||
def _get_single_item(self, index):
|
||||
file_item = self.file_list[index]
|
||||
# todo make sure this matches
|
||||
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
|
||||
w, h = img.size
|
||||
if w > h and file_item.scale_to_width < file_item.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
|
||||
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
|
||||
min_img_size = min(img.size)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# todo allow scaling and cropping, will be hard to add
|
||||
# scale and crop based on file item
|
||||
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
|
||||
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
|
||||
else:
|
||||
if self.random_crop:
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
img = img.resize((scale_size, scale_size), Image.BICUBIC)
|
||||
img = transforms.RandomCrop(self.resolution)(img)
|
||||
else:
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
||||
|
||||
img = self.transform(img)
|
||||
|
||||
# todo convert it all
|
||||
dataset_config_dict = {
|
||||
"is_reg": 1 if self.dataset_config.is_reg else 0,
|
||||
}
|
||||
|
||||
if self.caption_type is not None:
|
||||
prompt = self.get_caption_item(index)
|
||||
return img, prompt, dataset_config_dict
|
||||
else:
|
||||
return img, dataset_config_dict
|
||||
def _get_single_item(self, index) -> 'FileItemDTO':
|
||||
file_item = copy.deepcopy(self.file_list[index])
|
||||
file_item.load_and_process_image(self.transform)
|
||||
file_item.load_caption(self.caption_dict)
|
||||
return file_item
|
||||
|
||||
def __getitem__(self, item):
|
||||
if self.dataset_config.buckets:
|
||||
# for buckets we collate ourselves for now
|
||||
# todo allow a scheduler to dynamically make buckets
|
||||
# we collate ourselves
|
||||
if len(self.batch_indices) - 1 < item:
|
||||
# tried everything to solve this. No way to reset length when redoing things. Pick another index
|
||||
item = random.randint(0, len(self.batch_indices) - 1)
|
||||
idx_list = self.batch_indices[item]
|
||||
tensor_list = []
|
||||
prompt_list = []
|
||||
dataset_config_dict_list = []
|
||||
for idx in idx_list:
|
||||
if self.caption_type is not None:
|
||||
img, prompt, dataset_config_dict = self._get_single_item(idx)
|
||||
prompt_list.append(prompt)
|
||||
dataset_config_dict_list.append(dataset_config_dict)
|
||||
else:
|
||||
img, dataset_config_dict = self._get_single_item(idx)
|
||||
dataset_config_dict_list.append(dataset_config_dict)
|
||||
tensor_list.append(img.unsqueeze(0))
|
||||
|
||||
if self.caption_type is not None:
|
||||
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
|
||||
else:
|
||||
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
|
||||
return [self._get_single_item(idx) for idx in idx_list]
|
||||
else:
|
||||
# Dataloader is batching
|
||||
return self._get_single_item(item)
|
||||
|
||||
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||
# TODO do bucketing
|
||||
def get_dataloader_from_datasets(
|
||||
dataset_options,
|
||||
batch_size=1,
|
||||
sd: 'StableDiffusion' = None,
|
||||
) -> DataLoader:
|
||||
if dataset_options is None or len(dataset_options) == 0:
|
||||
return None
|
||||
|
||||
datasets = []
|
||||
has_buckets = False
|
||||
is_caching_latents = False
|
||||
|
||||
dataset_config_list = []
|
||||
# preprocess them all
|
||||
for dataset_option in dataset_options:
|
||||
if isinstance(dataset_option, DatasetConfig):
|
||||
config = dataset_option
|
||||
dataset_config_list.append(dataset_option)
|
||||
else:
|
||||
config = DatasetConfig(**dataset_option)
|
||||
# preprocess raw data
|
||||
split_configs = preprocess_dataset_raw_config([dataset_option])
|
||||
for x in split_configs:
|
||||
dataset_config_list.append(DatasetConfig(**x))
|
||||
|
||||
for config in dataset_config_list:
|
||||
|
||||
if config.type == 'image':
|
||||
dataset = AiToolkitDataset(config, batch_size=batch_size)
|
||||
dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
|
||||
datasets.append(dataset)
|
||||
if config.buckets:
|
||||
has_buckets = True
|
||||
if config.cache_latents or config.cache_latents_to_disk:
|
||||
is_caching_latents = True
|
||||
else:
|
||||
raise ValueError(f"invalid dataset type: {config.type}")
|
||||
|
||||
concatenated_dataset = ConcatDataset(datasets)
|
||||
|
||||
# todo build scheduler that can get buckets from all datasets that match
|
||||
# todo and evenly distribute reg images
|
||||
|
||||
def dto_collation(batch: List['FileItemDTO']):
|
||||
# create DTO batch
|
||||
batch = DataLoaderBatchDTO(
|
||||
file_items=batch
|
||||
)
|
||||
return batch
|
||||
|
||||
# check if is caching latents
|
||||
|
||||
|
||||
if has_buckets:
|
||||
# make sure they all have buckets
|
||||
for dataset in datasets:
|
||||
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
# just return as is
|
||||
return batch
|
||||
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=None, # we batch in the dataloader
|
||||
batch_size=None, # we batch in the datasets for now
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=custom_collate_fn, # Use the custom collate function
|
||||
num_workers=2
|
||||
collate_fn=dto_collation, # Use the custom collate function
|
||||
num_workers=4
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2
|
||||
num_workers=4,
|
||||
collate_fn=dto_collation
|
||||
)
|
||||
return data_loader
|
||||
|
||||
|
||||
def trigger_dataloader_setup_epoch(dataloader: DataLoader):
|
||||
# hacky but needed because of different types of datasets and dataloaders
|
||||
dataloader.len = None
|
||||
if isinstance(dataloader.dataset, list):
|
||||
for dataset in dataloader.dataset:
|
||||
if hasattr(dataset, 'datasets'):
|
||||
for sub_dataset in dataset.datasets:
|
||||
if hasattr(sub_dataset, 'setup_epoch'):
|
||||
sub_dataset.setup_epoch()
|
||||
sub_dataset.len = None
|
||||
elif hasattr(dataset, 'setup_epoch'):
|
||||
dataset.setup_epoch()
|
||||
dataset.len = None
|
||||
elif hasattr(dataloader.dataset, 'setup_epoch'):
|
||||
dataloader.dataset.setup_epoch()
|
||||
dataloader.dataset.len = None
|
||||
elif hasattr(dataloader.dataset, 'datasets'):
|
||||
dataloader.dataset.len = None
|
||||
for sub_dataset in dataloader.dataset.datasets:
|
||||
if hasattr(sub_dataset, 'setup_epoch'):
|
||||
sub_dataset.setup_epoch()
|
||||
sub_dataset.len = None
|
||||
|
||||
202
toolkit/data_transfer_object/data_loader.py
Normal file
202
toolkit/data_transfer_object/data_loader.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
import torch
|
||||
import random
|
||||
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
|
||||
printed_messages = []
|
||||
|
||||
|
||||
def print_once(msg):
|
||||
global printed_messages
|
||||
if msg not in printed_messages:
|
||||
print(msg)
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItemDTO(
|
||||
LatentCachingFileItemDTOMixin,
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
UnconditionalFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.path = kwargs.get('path', None)
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
# process width and height
|
||||
try:
|
||||
w, h = image_utils.get_image_size(self.path)
|
||||
except image_utils.UnknownImageFormat:
|
||||
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
||||
f'This process is faster for png, jpeg')
|
||||
img = exif_transpose(Image.open(self.path))
|
||||
h, w = img.size
|
||||
self.width: int = w
|
||||
self.height: int = h
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# self.caption_path: str = kwargs.get('caption_path', None)
|
||||
self.raw_caption: str = kwargs.get('raw_caption', None)
|
||||
# we scale first, then crop
|
||||
self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
|
||||
self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
|
||||
# crop values are from scaled size
|
||||
self.crop_x: int = kwargs.get('crop_x', 0)
|
||||
self.crop_y: int = kwargs.get('crop_y', 0)
|
||||
self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
|
||||
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_x', False)
|
||||
self.augments: List[str] = self.dataset_config.augments
|
||||
|
||||
self.network_weight: float = self.dataset_config.network_weight
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_control()
|
||||
self.cleanup_mask()
|
||||
self.cleanup_unconditional()
|
||||
|
||||
|
||||
class DataLoaderBatchDTO:
|
||||
def __init__(self, **kwargs):
|
||||
try:
|
||||
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
|
||||
is_latents_cached = self.file_items[0].is_latent_cached
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_latents: Union[torch.Tensor, None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
# if we have encoded latents, we concatenate them
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
# if self.file_items[0].control_tensor is not None:
|
||||
# if any have a control tensor, we concatenate them
|
||||
if any([x.control_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_control_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.control_tensor is not None:
|
||||
base_control_tensor = x.control_tensor
|
||||
break
|
||||
control_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.control_tensor is None:
|
||||
control_tensors.append(torch.zeros_like(base_control_tensor))
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
if any([x.mask_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_mask_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.mask_tensor is not None:
|
||||
base_mask_tensor = x.mask_tensor
|
||||
break
|
||||
mask_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.mask_tensor is None:
|
||||
mask_tensors.append(torch.zeros_like(base_mask_tensor))
|
||||
else:
|
||||
mask_tensors.append(x.mask_tensor)
|
||||
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
|
||||
|
||||
# add unaugmented tensors for ones with augments
|
||||
if any([x.unaugmented_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_unaugmented_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is not None:
|
||||
base_unaugmented_tensor = x.unaugmented_tensor
|
||||
break
|
||||
unaugmented_tensor = []
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is None:
|
||||
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
|
||||
else:
|
||||
unaugmented_tensor.append(x.unaugmented_tensor)
|
||||
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
||||
|
||||
# add unconditional tensors
|
||||
if any([x.unconditional_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_unconditional_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is not None:
|
||||
base_unconditional_tensor = x.unconditional_tensor
|
||||
break
|
||||
unconditional_tensor = []
|
||||
for x in self.file_items:
|
||||
if x.unconditional_tensor is None:
|
||||
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
|
||||
else:
|
||||
unconditional_tensor.append(x.unconditional_tensor)
|
||||
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
def get_is_reg_list(self):
|
||||
return [x.is_reg for x in self.file_items]
|
||||
|
||||
def get_network_weight_list(self):
|
||||
return [x.network_weight for x in self.file_items]
|
||||
|
||||
def get_caption_list(
|
||||
self,
|
||||
trigger=None,
|
||||
to_replace_list=None,
|
||||
add_if_not_present=True
|
||||
):
|
||||
return [x.get_caption(
|
||||
trigger=trigger,
|
||||
to_replace_list=to_replace_list,
|
||||
add_if_not_present=add_if_not_present
|
||||
) for x in self.file_items]
|
||||
|
||||
def get_caption_short_list(
|
||||
self,
|
||||
trigger=None,
|
||||
to_replace_list=None,
|
||||
add_if_not_present=True
|
||||
):
|
||||
return [x.get_caption(
|
||||
trigger=trigger,
|
||||
to_replace_list=to_replace_list,
|
||||
add_if_not_present=add_if_not_present,
|
||||
short_caption=True
|
||||
) for x in self.file_items]
|
||||
|
||||
def cleanup(self):
|
||||
del self.latents
|
||||
del self.tensor
|
||||
del self.control_tensor
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,8 @@ class Embedding:
|
||||
def __init__(
|
||||
self,
|
||||
sd: 'StableDiffusion',
|
||||
embed_config: 'EmbeddingConfig'
|
||||
embed_config: 'EmbeddingConfig',
|
||||
state_dict: OrderedDict = None,
|
||||
):
|
||||
self.name = embed_config.trigger
|
||||
self.sd = sd
|
||||
@@ -38,74 +39,115 @@ class Embedding:
|
||||
additional_tokens.append(f"{self.embed_config.trigger}_{i}")
|
||||
placeholder_tokens += additional_tokens
|
||||
|
||||
num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
|
||||
if num_added_tokens != self.embed_config.tokens:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
# handle dual tokenizer
|
||||
self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
|
||||
self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
|
||||
self.sd.text_encoder]
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
|
||||
# if length of token ids is more than number of orm embedding tokens fill with *
|
||||
if len(init_token_ids) > self.embed_config.tokens:
|
||||
init_token_ids = init_token_ids[:self.embed_config.tokens]
|
||||
elif len(init_token_ids) < self.embed_config.tokens:
|
||||
pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
|
||||
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
|
||||
self.placeholder_token_ids = []
|
||||
self.embedding_tokens = []
|
||||
|
||||
self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)
|
||||
print(f"Adding {placeholder_tokens} tokens to tokenizer")
|
||||
print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
# todo SDXL has 2 text encoders, need to do both for all of this
|
||||
self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))
|
||||
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
|
||||
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
|
||||
if num_added_tokens != self.embed_config.tokens:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
|
||||
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
|
||||
)
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
||||
with torch.no_grad():
|
||||
for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
|
||||
# if length of token ids is more than number of orm embedding tokens fill with *
|
||||
if len(init_token_ids) > self.embed_config.tokens:
|
||||
init_token_ids = init_token_ids[:self.embed_config.tokens]
|
||||
elif len(init_token_ids) < self.embed_config.tokens:
|
||||
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
|
||||
init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
|
||||
|
||||
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||
self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
|
||||
self.placeholder_token_ids.append(placeholder_token_ids)
|
||||
|
||||
# returns the string to have in the prompt to trigger the embedding
|
||||
def get_embedding_string(self):
|
||||
return self.embedding_tokens
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
with torch.no_grad():
|
||||
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
|
||||
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
|
||||
|
||||
# backup text encoder embeddings
|
||||
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
|
||||
|
||||
def restore_embeddings(self):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
|
||||
self.tokenizer_list,
|
||||
self.orig_embeds_params,
|
||||
self.placeholder_token_ids):
|
||||
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
index_no_updates[
|
||||
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
|
||||
with torch.no_grad():
|
||||
text_encoder.get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds[index_no_updates]
|
||||
|
||||
def get_trainable_params(self):
|
||||
# todo only get this one as we could have more than one
|
||||
return self.sd.text_encoder.get_input_embeddings().parameters()
|
||||
params = []
|
||||
for text_encoder in self.text_encoder_list:
|
||||
params += text_encoder.get_input_embeddings().parameters()
|
||||
return params
|
||||
|
||||
# make setter and getter for vec
|
||||
@property
|
||||
def vec(self):
|
||||
def _get_vec(self, text_encoder_idx=0):
|
||||
# should we get params instead
|
||||
# create vector from token embeds
|
||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
|
||||
# stack the tokens along batch axis adding that axis
|
||||
new_vector = torch.stack(
|
||||
[token_embeds[token_id] for token_id in self.placeholder_token_ids],
|
||||
[token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
|
||||
dim=0
|
||||
)
|
||||
return new_vector
|
||||
|
||||
@vec.setter
|
||||
def vec(self, new_vector):
|
||||
def _set_vec(self, new_vector, text_encoder_idx=0):
|
||||
# shape is (1, 768) for SD 1.5 for 1 token
|
||||
token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
|
||||
for i in range(new_vector.shape[0]):
|
||||
# apply the weights to the placeholder tokens while preserving gradient
|
||||
token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()
|
||||
x = 1
|
||||
token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
|
||||
|
||||
# make setter and getter for vec
|
||||
@property
|
||||
def vec(self):
|
||||
return self._get_vec(0)
|
||||
|
||||
@vec.setter
|
||||
def vec(self, new_vector):
|
||||
self._set_vec(new_vector, 0)
|
||||
|
||||
@property
|
||||
def vec2(self):
|
||||
return self._get_vec(1)
|
||||
|
||||
@vec2.setter
|
||||
def vec2(self, new_vector):
|
||||
self._set_vec(new_vector, 1)
|
||||
|
||||
# diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
|
||||
# however, on training we don't use that pipeline, so we have to do it ourselves
|
||||
def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||
output_prompt = prompt
|
||||
default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]
|
||||
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = self.embedding_tokens if expand_token else self.trigger
|
||||
replace_with = embedding_tokens if expand_token else self.trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
@@ -120,7 +162,7 @@ class Embedding:
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = prompt.count(replace_with)
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
@@ -128,10 +170,21 @@ class Embedding:
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
def state_dict(self):
|
||||
if self.sd.is_xl:
|
||||
state_dict = OrderedDict()
|
||||
state_dict['clip_l'] = self.vec
|
||||
state_dict['clip_g'] = self.vec2
|
||||
else:
|
||||
state_dict = OrderedDict()
|
||||
state_dict['emb_params'] = self.vec
|
||||
|
||||
return state_dict
|
||||
|
||||
def save(self, filename):
|
||||
# todo check to see how to get the vector out of the embedding
|
||||
|
||||
@@ -145,13 +198,14 @@ class Embedding:
|
||||
"sd_checkpoint_name": None,
|
||||
"notes": None,
|
||||
}
|
||||
# TODO we do not currently support this. Check how auto is doing it. Only safetensors supported sor sdxl
|
||||
if filename.endswith('.pt'):
|
||||
torch.save(embedding_data, filename)
|
||||
elif filename.endswith('.bin'):
|
||||
torch.save(embedding_data, filename)
|
||||
elif filename.endswith('.safetensors'):
|
||||
# save the embedding as a safetensors file
|
||||
state_dict = {"emb_params": self.vec}
|
||||
state_dict = self.state_dict()
|
||||
# add all embedding data (except string_to_param), to metadata
|
||||
metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
|
||||
metadata["string_to_param"] = {"*": "emb_params"}
|
||||
@@ -163,6 +217,7 @@ class Embedding:
|
||||
path = os.path.realpath(file_path)
|
||||
filename = os.path.basename(path)
|
||||
name, ext = os.path.splitext(filename)
|
||||
tensors = {}
|
||||
ext = ext.upper()
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
_, second_ext = os.path.splitext(name)
|
||||
@@ -170,10 +225,12 @@ class Embedding:
|
||||
return
|
||||
|
||||
if ext in ['.BIN', '.PT']:
|
||||
# todo check this
|
||||
if self.sd.is_xl:
|
||||
raise Exception("XL not supported yet for bin, pt")
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
# rebuild the embedding from the safetensors file if it has it
|
||||
tensors = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
for k in f.keys():
|
||||
@@ -195,26 +252,32 @@ class Embedding:
|
||||
else:
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict,
|
||||
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
if self.sd.is_xl:
|
||||
self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
|
||||
self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
|
||||
if 'step' in data:
|
||||
self.step = int(data['step'])
|
||||
else:
|
||||
raise Exception(
|
||||
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict,
|
||||
'_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
if 'step' in data:
|
||||
self.step = int(data['step'])
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
self.vec = emb.detach().to(device, dtype=torch.float32)
|
||||
if 'step' in data:
|
||||
self.step = int(data['step'])
|
||||
|
||||
self.vec = emb.detach().to(device, dtype=torch.float32)
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
|
||||
import atexit
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
|
||||
|
||||
@@ -112,7 +119,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
width = int(w)
|
||||
height = int(h)
|
||||
elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
|
||||
and (data[12:16] == b'IHDR')):
|
||||
and (data[12:16] == b'IHDR')):
|
||||
# PNGs
|
||||
imgtype = PNG
|
||||
w, h = struct.unpack(">LL", data[16:24])
|
||||
@@ -190,7 +197,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
9: (4, boChar + "l"), # SLONG
|
||||
10: (8, boChar + "ll"), # SRATIONAL
|
||||
11: (4, boChar + "f"), # FLOAT
|
||||
12: (8, boChar + "d") # DOUBLE
|
||||
12: (8, boChar + "d") # DOUBLE
|
||||
}
|
||||
ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
|
||||
try:
|
||||
@@ -206,7 +213,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
input.seek(entryOffset)
|
||||
tag = input.read(2)
|
||||
tag = struct.unpack(boChar + "H", tag)[0]
|
||||
if(tag == 256 or tag == 257):
|
||||
if (tag == 256 or tag == 257):
|
||||
# if type indicates that value fits into 4 bytes, value
|
||||
# offset is not an offset but value itself
|
||||
type = input.read(2)
|
||||
@@ -229,7 +236,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
except Exception as e:
|
||||
raise UnknownImageFormat(str(e))
|
||||
elif size >= 2:
|
||||
# see http://en.wikipedia.org/wiki/ICO_(file_format)
|
||||
# see http://en.wikipedia.org/wiki/ICO_(file_format)
|
||||
imgtype = 'ICO'
|
||||
input.seek(0)
|
||||
reserved = input.read(2)
|
||||
@@ -350,13 +357,13 @@ def main(argv=None):
|
||||
|
||||
prs.add_option('-v', '--verbose',
|
||||
dest='verbose',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
prs.add_option('-q', '--quiet',
|
||||
dest='quiet',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
prs.add_option('-t', '--test',
|
||||
dest='run_tests',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
|
||||
argv = list(argv) if argv is not None else sys.argv[1:]
|
||||
(opts, args) = prs.parse_args(args=argv)
|
||||
@@ -417,6 +424,60 @@ def main(argv=None):
|
||||
return EX_OK
|
||||
|
||||
|
||||
is_window_shown = False
|
||||
|
||||
|
||||
def show_img(img, name='AI Toolkit'):
|
||||
global is_window_shown
|
||||
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
cv2.imshow(name, img[:, :, ::-1])
|
||||
k = cv2.waitKey(10) & 0xFF
|
||||
if k == 27: # Esc key to stop
|
||||
print('\nESC pressed, stopping')
|
||||
raise KeyboardInterrupt
|
||||
if not is_window_shown:
|
||||
is_window_shown = True
|
||||
|
||||
|
||||
|
||||
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
|
||||
# if rank is 4
|
||||
if len(imgs.shape) == 4:
|
||||
img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
|
||||
else:
|
||||
img_list = [imgs]
|
||||
# put images side by side
|
||||
img = torch.cat(img_list, dim=3)
|
||||
# img is -1 to 1, convert to 0 to 255
|
||||
img = img / 2 + 0.5
|
||||
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
||||
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
||||
# convert to numpy Move channel to last
|
||||
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
||||
# convert to uint8
|
||||
img_numpy = img_numpy.astype(np.uint8)
|
||||
show_img(img_numpy[0], name=name)
|
||||
|
||||
|
||||
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||
# decode latents
|
||||
if vae.device == 'cpu':
|
||||
vae.to(latents.device)
|
||||
latents = latents / vae.config['scaling_factor']
|
||||
imgs = vae.decode(latents).sample
|
||||
show_tensors(imgs, name=name)
|
||||
|
||||
|
||||
|
||||
def on_exit():
|
||||
if is_window_shown:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
atexit.register(on_exit)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(main(argv=sys.argv[1:]))
|
||||
|
||||
sys.exit(main(argv=sys.argv[1:]))
|
||||
|
||||
410
toolkit/inversion_utils.py
Normal file
410
toolkit/inversion_utils.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
|
||||
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
def mu_tilde(model, xt, x0, timestep):
|
||||
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
||||
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
alpha_t = model.scheduler.alphas[timestep]
|
||||
beta_t = 1 - alpha_t
|
||||
alpha_bar = model.scheduler.alphas_cumprod[timestep]
|
||||
return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + (
|
||||
(alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
|
||||
|
||||
|
||||
def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50):
|
||||
"""
|
||||
Samples from P(x_1:T|x_0)
|
||||
"""
|
||||
# torch.manual_seed(43256465436)
|
||||
alpha_bar = sd.noise_scheduler.alphas_cumprod
|
||||
sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
|
||||
alphas = sd.noise_scheduler.alphas
|
||||
betas = 1 - alphas
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# sd.unet.in_channels,
|
||||
# sd.unet.sample_size,
|
||||
# sd.unet.sample_size)
|
||||
variance_noise_shape = list(sample.shape)
|
||||
variance_noise_shape[0] = num_inference_steps
|
||||
|
||||
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16)
|
||||
for t in reversed(timesteps):
|
||||
idx = t_to_idx[int(t)]
|
||||
xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t]
|
||||
xts = torch.cat([xts, sample], dim=0)
|
||||
|
||||
return xts
|
||||
|
||||
|
||||
def encode_text(model, prompts):
|
||||
text_input = model.tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=model.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with torch.no_grad():
|
||||
text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
|
||||
return text_encoding
|
||||
|
||||
|
||||
def forward_step(sd: StableDiffusion, model_output, timestep, sample):
|
||||
next_timestep = min(
|
||||
sd.noise_scheduler.config['num_train_timesteps'] - 2,
|
||||
timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
|
||||
# alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
|
||||
# 5. TODO: simple noising implementation
|
||||
next_sample = sd.noise_scheduler.add_noise(
|
||||
pred_original_sample,
|
||||
model_output,
|
||||
torch.LongTensor([next_timestep]))
|
||||
return next_sample
|
||||
|
||||
|
||||
def get_variance(sd: StableDiffusion, timestep): # , prev_timestep):
|
||||
prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
return variance
|
||||
|
||||
|
||||
def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
|
||||
if sd.is_xl:
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
|
||||
height = h * VAE_SCALE_FACTOR
|
||||
width = w * VAE_SCALE_FACTOR
|
||||
|
||||
dtype = latents.dtype
|
||||
# just do it without any cropping nonsense
|
||||
target_size = (height, width)
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||
|
||||
batch_time_ids = torch.cat(
|
||||
[add_time_ids for _ in range(bs)]
|
||||
)
|
||||
return batch_time_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def inversion_forward_process(
|
||||
sd: StableDiffusion,
|
||||
sample: torch.Tensor,
|
||||
conditional_embeddings: PromptEmbeds,
|
||||
unconditional_embeddings: PromptEmbeds,
|
||||
etas=None,
|
||||
prog_bar=False,
|
||||
cfg_scale=3.5,
|
||||
num_inference_steps=50, eps=None
|
||||
):
|
||||
current_num_timesteps = len(sd.noise_scheduler.timesteps)
|
||||
sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
|
||||
|
||||
timesteps = sd.noise_scheduler.timesteps.to(sd.device)
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# sd.unet.in_channels,
|
||||
# sd.unet.sample_size,
|
||||
# sd.unet.sample_size
|
||||
# )
|
||||
variance_noise_shape = list(sample.shape)
|
||||
variance_noise_shape[0] = num_inference_steps
|
||||
if etas is None or (type(etas) in [int, float] and etas == 0):
|
||||
eta_is_zero = True
|
||||
zs = None
|
||||
else:
|
||||
eta_is_zero = False
|
||||
if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
|
||||
xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
|
||||
alpha_bar = sd.noise_scheduler.alphas_cumprod
|
||||
zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
|
||||
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
noisy_sample = sample
|
||||
op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
|
||||
|
||||
for timestep in op:
|
||||
idx = t_to_idx[int(timestep)]
|
||||
# 1. predict noise residual
|
||||
if not eta_is_zero:
|
||||
noisy_sample = xts[idx][None]
|
||||
|
||||
added_cond_kwargs = {}
|
||||
|
||||
with torch.no_grad():
|
||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||
unconditional_embeddings, # negative embedding
|
||||
conditional_embeddings, # positive embedding
|
||||
1, # batch size
|
||||
)
|
||||
if sd.is_xl:
|
||||
add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
|
||||
# add extra for cfg
|
||||
add_time_ids = torch.cat(
|
||||
[add_time_ids] * 2, dim=0
|
||||
)
|
||||
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
|
||||
# double up for cfg
|
||||
latent_model_input = torch.cat(
|
||||
[noisy_sample] * 2, dim=0
|
||||
)
|
||||
|
||||
noise_pred = sd.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
# out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
|
||||
# cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
|
||||
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if eta_is_zero:
|
||||
# 2. compute more noisy image and set x_t -> x_t+1
|
||||
noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
|
||||
xts = None
|
||||
|
||||
else:
|
||||
xtm1 = xts[idx + 1][None]
|
||||
# pred of x0
|
||||
pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
|
||||
timestep] ** 0.5
|
||||
|
||||
# direction to xt
|
||||
prev_timestep = timestep - sd.noise_scheduler.config[
|
||||
'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
|
||||
alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
|
||||
|
||||
variance = get_variance(sd, timestep)
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
|
||||
|
||||
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
||||
zs[idx] = z
|
||||
|
||||
# correction to avoid error accumulation
|
||||
xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
|
||||
xts[idx + 1] = xtm1
|
||||
|
||||
if not zs is None:
|
||||
zs[-1] = torch.zeros_like(zs[-1])
|
||||
|
||||
# restore timesteps
|
||||
sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
|
||||
|
||||
return noisy_sample, zs, xts
|
||||
|
||||
|
||||
#
|
||||
# def inversion_forward_process(
|
||||
# model,
|
||||
# sample,
|
||||
# etas=None,
|
||||
# prog_bar=False,
|
||||
# prompt="",
|
||||
# cfg_scale=3.5,
|
||||
# num_inference_steps=50, eps=None
|
||||
# ):
|
||||
# if not prompt == "":
|
||||
# text_embeddings = encode_text(model, prompt)
|
||||
# uncond_embedding = encode_text(model, "")
|
||||
# timesteps = model.scheduler.timesteps.to(model.device)
|
||||
# variance_noise_shape = (
|
||||
# num_inference_steps,
|
||||
# model.unet.in_channels,
|
||||
# model.unet.sample_size,
|
||||
# model.unet.sample_size)
|
||||
# if etas is None or (type(etas) in [int, float] and etas == 0):
|
||||
# eta_is_zero = True
|
||||
# zs = None
|
||||
# else:
|
||||
# eta_is_zero = False
|
||||
# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
|
||||
# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
|
||||
# alpha_bar = model.scheduler.alphas_cumprod
|
||||
# zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
|
||||
#
|
||||
# t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
||||
# noisy_sample = sample
|
||||
# op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
|
||||
#
|
||||
# for t in op:
|
||||
# idx = t_to_idx[int(t)]
|
||||
# # 1. predict noise residual
|
||||
# if not eta_is_zero:
|
||||
# noisy_sample = xts[idx][None]
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
|
||||
# if not prompt == "":
|
||||
# cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
|
||||
#
|
||||
# if not prompt == "":
|
||||
# ## classifier free guidance
|
||||
# noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
|
||||
# else:
|
||||
# noise_pred = out.sample
|
||||
#
|
||||
# if eta_is_zero:
|
||||
# # 2. compute more noisy image and set x_t -> x_t+1
|
||||
# noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
|
||||
#
|
||||
# else:
|
||||
# xtm1 = xts[idx + 1][None]
|
||||
# # pred of x0
|
||||
# pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
|
||||
#
|
||||
# # direction to xt
|
||||
# prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
# alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
# prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
#
|
||||
# variance = get_variance(model, t)
|
||||
# pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
|
||||
#
|
||||
# mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
#
|
||||
# z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
|
||||
# zs[idx] = z
|
||||
#
|
||||
# # correction to avoid error accumulation
|
||||
# xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
|
||||
# xts[idx + 1] = xtm1
|
||||
#
|
||||
# if not zs is None:
|
||||
# zs[-1] = torch.zeros_like(zs[-1])
|
||||
#
|
||||
# return noisy_sample, zs, xts
|
||||
|
||||
|
||||
def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = model.scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
||||
variance = get_variance(model, timestep) # , prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
# Take care of asymetric reverse process (asyrp)
|
||||
model_output_direction = model_output
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
# 8. Add noice if eta > 0
|
||||
if eta > 0:
|
||||
if variance_noise is None:
|
||||
variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
|
||||
sigma_z = eta * variance ** (0.5) * variance_noise
|
||||
prev_sample = prev_sample + sigma_z
|
||||
|
||||
return prev_sample
|
||||
|
||||
|
||||
def inversion_reverse_process(
|
||||
model,
|
||||
xT,
|
||||
etas=0,
|
||||
prompts="",
|
||||
cfg_scales=None,
|
||||
prog_bar=False,
|
||||
zs=None,
|
||||
controller=None,
|
||||
asyrp=False):
|
||||
batch_size = len(prompts)
|
||||
|
||||
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
|
||||
|
||||
text_embeddings = encode_text(model, prompts)
|
||||
uncond_embedding = encode_text(model, [""] * batch_size)
|
||||
|
||||
if etas is None: etas = 0
|
||||
if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
|
||||
assert len(etas) == model.scheduler.num_inference_steps
|
||||
timesteps = model.scheduler.timesteps.to(model.device)
|
||||
|
||||
xt = xT.expand(batch_size, -1, -1, -1)
|
||||
op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
|
||||
|
||||
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
||||
|
||||
for t in op:
|
||||
idx = t_to_idx[int(t)]
|
||||
## Unconditional embedding
|
||||
with torch.no_grad():
|
||||
uncond_out = model.unet.forward(xt, timestep=t,
|
||||
encoder_hidden_states=uncond_embedding)
|
||||
|
||||
## Conditional embedding
|
||||
if prompts:
|
||||
with torch.no_grad():
|
||||
cond_out = model.unet.forward(xt, timestep=t,
|
||||
encoder_hidden_states=text_embeddings)
|
||||
|
||||
z = zs[idx] if not zs is None else None
|
||||
z = z.expand(batch_size, -1, -1, -1)
|
||||
if prompts:
|
||||
## classifier free guidance
|
||||
noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
|
||||
else:
|
||||
noise_pred = uncond_out.sample
|
||||
# 2. compute less noisy image and set x_t -> x_t-1
|
||||
xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
|
||||
if controller is not None:
|
||||
xt = controller.step_callback(xt)
|
||||
return xt, zs
|
||||
157
toolkit/ip_adapter.py
Normal file
157
toolkit/ip_adapter.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from PIL import Image
|
||||
from torch.nn import Parameter
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
import weakref
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter"""
|
||||
|
||||
def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
|
||||
super().__init__()
|
||||
self.config = adapter_config
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
self.device = self.sd_ref().unet.device
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
|
||||
if adapter_config.type == 'ip':
|
||||
# ip-adapter
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=sd.unet.config['cross_attention_dim'],
|
||||
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=4,
|
||||
)
|
||||
elif adapter_config.type == 'ip+':
|
||||
# ip-adapter-plus
|
||||
num_tokens = 16
|
||||
image_proj_model = Resampler(
|
||||
dim=sd.unet.config['cross_attention_dim'],
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=num_tokens,
|
||||
embedding_dim=self.image_encoder.config.hidden_size,
|
||||
output_dim=sd.unet.config['cross_attention_dim'],
|
||||
ff_mult=4
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown adapter type: {adapter_config.type}")
|
||||
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = sd.unet.state_dict()
|
||||
for name in sd.unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = sd.unet.config['block_out_channels'][-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = sd.unet.config['block_out_channels'][block_id]
|
||||
else:
|
||||
# they didnt have this, but would lead to undefined below
|
||||
raise ValueError(f"unknown attn processor name: {name}")
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||
}
|
||||
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
sd.unet.set_attn_processor(attn_procs)
|
||||
adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
||||
|
||||
sd.adapter = self
|
||||
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
|
||||
self.image_proj_model = image_proj_model
|
||||
self.adapter_modules = adapter_modules
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.image_encoder.to(*args, **kwargs)
|
||||
self.image_proj_model.to(*args, **kwargs)
|
||||
self.adapter_modules.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||
|
||||
def state_dict(self) -> OrderedDict:
|
||||
state_dict = OrderedDict()
|
||||
state_dict["image_proj"] = self.image_proj_model.state_dict()
|
||||
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
|
||||
return state_dict
|
||||
|
||||
def set_scale(self, scale):
|
||||
for attn_processor in self.pipe.unet.attn_processors.values():
|
||||
if isinstance(attn_processor, IPAttnProcessor):
|
||||
attn_processor.scale = scale
|
||||
|
||||
@torch.no_grad()
|
||||
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
|
||||
# todo: add support for sdxl
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
return clip_image_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
|
||||
# tensors should be 0-1
|
||||
# todo: add support for sdxl
|
||||
if tensors_0_1.ndim == 3:
|
||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
return clip_image_embeds
|
||||
|
||||
# use drop for prompt dropout, or negatives
|
||||
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
|
||||
clip_image_embeds = clip_image_embeds.detach()
|
||||
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
|
||||
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
||||
return embeddings
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
for attn_processor in self.adapter_modules:
|
||||
yield from attn_processor.parameters(recurse)
|
||||
yield from self.image_proj_model.parameters(recurse)
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||
3154
toolkit/keymaps/stable_diffusion_locon_sdxl.json
Normal file
3154
toolkit/keymaps/stable_diffusion_locon_sdxl.json
Normal file
File diff suppressed because it is too large
Load Diff
3498
toolkit/keymaps/stable_diffusion_refiner.json
Normal file
3498
toolkit/keymaps/stable_diffusion_refiner.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
Normal file
Binary file not shown.
27
toolkit/keymaps/stable_diffusion_refiner_unmatched.json
Normal file
27
toolkit/keymaps/stable_diffusion_refiner_unmatched.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"ldm": {
|
||||
"conditioner.embedders.0.model.logit_scale": {
|
||||
"shape": [],
|
||||
"min": 4.60546875,
|
||||
"max": 4.60546875
|
||||
},
|
||||
"conditioner.embedders.0.model.text_projection": {
|
||||
"shape": [
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"min": -0.15966796875,
|
||||
"max": 0.230712890625
|
||||
}
|
||||
},
|
||||
"diffusers": {
|
||||
"te1_text_projection.weight": {
|
||||
"shape": [
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"min": -0.15966796875,
|
||||
"max": 0.230712890625
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,8 +4,6 @@
|
||||
"cond_stage_model.model.ln_final.weight": "te_text_model.final_layer_norm.weight",
|
||||
"cond_stage_model.model.positional_embedding": "te_text_model.embeddings.position_embedding.weight",
|
||||
"cond_stage_model.model.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": "te_text_model.encoder.layers.0.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": "te_text_model.encoder.layers.0.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.0.ln_1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias",
|
||||
@@ -16,8 +14,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": "te_text_model.encoder.layers.1.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight": "te_text_model.encoder.layers.1.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.1.ln_1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias",
|
||||
@@ -28,8 +24,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": "te_text_model.encoder.layers.10.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight": "te_text_model.encoder.layers.10.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.10.ln_1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias",
|
||||
@@ -40,8 +34,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": "te_text_model.encoder.layers.11.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight": "te_text_model.encoder.layers.11.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.11.ln_1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias",
|
||||
@@ -52,8 +44,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": "te_text_model.encoder.layers.12.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight": "te_text_model.encoder.layers.12.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias": "te_text_model.encoder.layers.12.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight": "te_text_model.encoder.layers.12.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.12.ln_1.bias": "te_text_model.encoder.layers.12.layer_norm1.bias",
|
||||
@@ -64,8 +54,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight": "te_text_model.encoder.layers.12.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias": "te_text_model.encoder.layers.12.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight": "te_text_model.encoder.layers.12.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": "te_text_model.encoder.layers.13.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight": "te_text_model.encoder.layers.13.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias": "te_text_model.encoder.layers.13.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight": "te_text_model.encoder.layers.13.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.13.ln_1.bias": "te_text_model.encoder.layers.13.layer_norm1.bias",
|
||||
@@ -76,8 +64,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight": "te_text_model.encoder.layers.13.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias": "te_text_model.encoder.layers.13.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight": "te_text_model.encoder.layers.13.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": "te_text_model.encoder.layers.14.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight": "te_text_model.encoder.layers.14.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias": "te_text_model.encoder.layers.14.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight": "te_text_model.encoder.layers.14.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.14.ln_1.bias": "te_text_model.encoder.layers.14.layer_norm1.bias",
|
||||
@@ -88,8 +74,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight": "te_text_model.encoder.layers.14.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias": "te_text_model.encoder.layers.14.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight": "te_text_model.encoder.layers.14.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": "te_text_model.encoder.layers.15.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight": "te_text_model.encoder.layers.15.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias": "te_text_model.encoder.layers.15.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight": "te_text_model.encoder.layers.15.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.15.ln_1.bias": "te_text_model.encoder.layers.15.layer_norm1.bias",
|
||||
@@ -100,8 +84,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight": "te_text_model.encoder.layers.15.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias": "te_text_model.encoder.layers.15.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight": "te_text_model.encoder.layers.15.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": "te_text_model.encoder.layers.16.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight": "te_text_model.encoder.layers.16.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias": "te_text_model.encoder.layers.16.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight": "te_text_model.encoder.layers.16.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.16.ln_1.bias": "te_text_model.encoder.layers.16.layer_norm1.bias",
|
||||
@@ -112,8 +94,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight": "te_text_model.encoder.layers.16.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias": "te_text_model.encoder.layers.16.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight": "te_text_model.encoder.layers.16.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": "te_text_model.encoder.layers.17.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight": "te_text_model.encoder.layers.17.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias": "te_text_model.encoder.layers.17.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight": "te_text_model.encoder.layers.17.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.17.ln_1.bias": "te_text_model.encoder.layers.17.layer_norm1.bias",
|
||||
@@ -124,8 +104,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight": "te_text_model.encoder.layers.17.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias": "te_text_model.encoder.layers.17.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight": "te_text_model.encoder.layers.17.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": "te_text_model.encoder.layers.18.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight": "te_text_model.encoder.layers.18.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias": "te_text_model.encoder.layers.18.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight": "te_text_model.encoder.layers.18.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.18.ln_1.bias": "te_text_model.encoder.layers.18.layer_norm1.bias",
|
||||
@@ -136,8 +114,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight": "te_text_model.encoder.layers.18.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias": "te_text_model.encoder.layers.18.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight": "te_text_model.encoder.layers.18.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": "te_text_model.encoder.layers.19.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight": "te_text_model.encoder.layers.19.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias": "te_text_model.encoder.layers.19.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight": "te_text_model.encoder.layers.19.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.19.ln_1.bias": "te_text_model.encoder.layers.19.layer_norm1.bias",
|
||||
@@ -148,8 +124,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight": "te_text_model.encoder.layers.19.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias": "te_text_model.encoder.layers.19.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight": "te_text_model.encoder.layers.19.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": "te_text_model.encoder.layers.2.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight": "te_text_model.encoder.layers.2.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.2.ln_1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias",
|
||||
@@ -160,8 +134,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": "te_text_model.encoder.layers.20.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight": "te_text_model.encoder.layers.20.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias": "te_text_model.encoder.layers.20.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight": "te_text_model.encoder.layers.20.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.20.ln_1.bias": "te_text_model.encoder.layers.20.layer_norm1.bias",
|
||||
@@ -172,8 +144,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight": "te_text_model.encoder.layers.20.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias": "te_text_model.encoder.layers.20.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight": "te_text_model.encoder.layers.20.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": "te_text_model.encoder.layers.21.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight": "te_text_model.encoder.layers.21.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias": "te_text_model.encoder.layers.21.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight": "te_text_model.encoder.layers.21.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.21.ln_1.bias": "te_text_model.encoder.layers.21.layer_norm1.bias",
|
||||
@@ -184,8 +154,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight": "te_text_model.encoder.layers.21.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias": "te_text_model.encoder.layers.21.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight": "te_text_model.encoder.layers.21.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": "te_text_model.encoder.layers.22.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight": "te_text_model.encoder.layers.22.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias": "te_text_model.encoder.layers.22.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight": "te_text_model.encoder.layers.22.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.22.ln_1.bias": "te_text_model.encoder.layers.22.layer_norm1.bias",
|
||||
@@ -196,8 +164,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight": "te_text_model.encoder.layers.22.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias": "te_text_model.encoder.layers.22.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight": "te_text_model.encoder.layers.22.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": "te_text_model.encoder.layers.3.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight": "te_text_model.encoder.layers.3.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.3.ln_1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias",
|
||||
@@ -208,8 +174,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": "te_text_model.encoder.layers.4.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight": "te_text_model.encoder.layers.4.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.4.ln_1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias",
|
||||
@@ -220,8 +184,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": "te_text_model.encoder.layers.5.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight": "te_text_model.encoder.layers.5.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.5.ln_1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias",
|
||||
@@ -232,8 +194,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": "te_text_model.encoder.layers.6.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight": "te_text_model.encoder.layers.6.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.6.ln_1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias",
|
||||
@@ -244,8 +204,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": "te_text_model.encoder.layers.7.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight": "te_text_model.encoder.layers.7.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.7.ln_1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias",
|
||||
@@ -256,8 +214,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": "te_text_model.encoder.layers.8.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight": "te_text_model.encoder.layers.8.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.8.ln_1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias",
|
||||
@@ -268,8 +224,6 @@
|
||||
"cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": "te_text_model.encoder.layers.9.self_attn.MERGED.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight": "te_text_model.encoder.layers.9.self_attn.MERGED.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.9.ln_1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias",
|
||||
@@ -530,11 +484,11 @@
|
||||
"first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
|
||||
"model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
|
||||
"model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
|
||||
"model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
|
||||
@@ -566,31 +520,31 @@
|
||||
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_time_embedding.linear_2.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight",
|
||||
"model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias",
|
||||
"model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight",
|
||||
"model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias",
|
||||
"model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight",
|
||||
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
|
||||
@@ -624,11 +578,11 @@
|
||||
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
|
||||
"model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
|
||||
"model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
|
||||
@@ -662,11 +616,11 @@
|
||||
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
|
||||
@@ -700,11 +654,11 @@
|
||||
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
|
||||
"model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
|
||||
"model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
|
||||
@@ -738,11 +692,11 @@
|
||||
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
|
||||
@@ -776,11 +730,11 @@
|
||||
"model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias",
|
||||
"model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight",
|
||||
"model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
|
||||
@@ -812,11 +766,11 @@
|
||||
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
|
||||
@@ -826,11 +780,11 @@
|
||||
"model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
|
||||
"model.diffusion_model.out.2.bias": "unet_conv_out.bias",
|
||||
"model.diffusion_model.out.2.weight": "unet_conv_out.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
|
||||
@@ -838,11 +792,11 @@
|
||||
"model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
|
||||
"model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
|
||||
"model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
|
||||
@@ -850,11 +804,11 @@
|
||||
"model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
|
||||
"model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
|
||||
"model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
|
||||
"model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight",
|
||||
@@ -888,11 +842,11 @@
|
||||
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight",
|
||||
@@ -926,11 +880,11 @@
|
||||
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
|
||||
@@ -940,11 +894,11 @@
|
||||
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
|
||||
"model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
|
||||
"model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
|
||||
"model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
|
||||
@@ -978,11 +932,11 @@
|
||||
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
|
||||
@@ -1016,11 +970,11 @@
|
||||
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
|
||||
@@ -1056,11 +1010,11 @@
|
||||
"model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
|
||||
"model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
|
||||
"model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
|
||||
@@ -1094,11 +1048,11 @@
|
||||
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
|
||||
@@ -1132,11 +1086,11 @@
|
||||
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
|
||||
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
|
||||
"model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
|
||||
@@ -1172,11 +1126,11 @@
|
||||
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight",
|
||||
"model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias",
|
||||
"model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight",
|
||||
"model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight",
|
||||
"model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight",
|
||||
"model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight",
|
||||
"model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias",
|
||||
"model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight",
|
||||
@@ -1213,7 +1167,7 @@
|
||||
"model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
|
||||
"model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
|
||||
"model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
|
||||
"model.diffusion_model.time_embed.2.weight": "unet_mid_block.resnets.1.time_emb_proj.weight"
|
||||
"model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
|
||||
},
|
||||
"ldm_diffusers_shape_map": {
|
||||
"first_stage_model.decoder.mid.attn_1.k.weight": [
|
||||
@@ -1264,62 +1218,6 @@
|
||||
512
|
||||
]
|
||||
],
|
||||
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": [
|
||||
[
|
||||
128,
|
||||
256,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
128,
|
||||
256,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": [
|
||||
[
|
||||
256,
|
||||
512,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
256,
|
||||
512,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": [
|
||||
[
|
||||
256,
|
||||
128,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
256,
|
||||
128,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": [
|
||||
[
|
||||
512,
|
||||
256,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
512,
|
||||
256,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"first_stage_model.encoder.mid.attn_1.k.weight": [
|
||||
[
|
||||
512,
|
||||
@@ -1367,230 +1265,6 @@
|
||||
512,
|
||||
512
|
||||
]
|
||||
],
|
||||
"first_stage_model.post_quant_conv.weight": [
|
||||
[
|
||||
4,
|
||||
4,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
4,
|
||||
4,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"first_stage_model.quant_conv.weight": [
|
||||
[
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.input_blocks.4.0.skip_connection.weight": [
|
||||
[
|
||||
640,
|
||||
320,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
640,
|
||||
320,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.input_blocks.7.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.0.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.1.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.10.0.skip_connection.weight": [
|
||||
[
|
||||
320,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
320,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.11.0.skip_connection.weight": [
|
||||
[
|
||||
320,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
320,
|
||||
640,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.2.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.3.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.4.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
2560,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.5.0.skip_connection.weight": [
|
||||
[
|
||||
1280,
|
||||
1920,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
1280,
|
||||
1920,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.6.0.skip_connection.weight": [
|
||||
[
|
||||
640,
|
||||
1920,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
640,
|
||||
1920,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.7.0.skip_connection.weight": [
|
||||
[
|
||||
640,
|
||||
1280,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
640,
|
||||
1280,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.8.0.skip_connection.weight": [
|
||||
[
|
||||
640,
|
||||
960,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
640,
|
||||
960,
|
||||
1,
|
||||
1
|
||||
]
|
||||
],
|
||||
"model.diffusion_model.output_blocks.9.0.skip_connection.weight": [
|
||||
[
|
||||
320,
|
||||
960,
|
||||
1,
|
||||
1
|
||||
],
|
||||
[
|
||||
320,
|
||||
960,
|
||||
1,
|
||||
1
|
||||
]
|
||||
]
|
||||
},
|
||||
"ldm_diffusers_operator_map": {
|
||||
@@ -1606,8 +1280,7 @@
|
||||
"te_text_model.encoder.layers.0.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.0.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.0.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.0.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1621,8 +1294,7 @@
|
||||
"te_text_model.encoder.layers.1.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.1.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.1.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.1.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1636,8 +1308,7 @@
|
||||
"te_text_model.encoder.layers.10.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.10.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.10.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.10.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1651,8 +1322,7 @@
|
||||
"te_text_model.encoder.layers.11.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.11.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.11.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.11.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1666,8 +1336,7 @@
|
||||
"te_text_model.encoder.layers.12.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.12.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.12.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.12.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1681,8 +1350,7 @@
|
||||
"te_text_model.encoder.layers.13.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.13.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.13.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.13.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1696,8 +1364,7 @@
|
||||
"te_text_model.encoder.layers.14.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.14.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.14.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.14.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1711,8 +1378,7 @@
|
||||
"te_text_model.encoder.layers.15.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.15.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.15.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.15.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1726,8 +1392,7 @@
|
||||
"te_text_model.encoder.layers.16.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.16.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.16.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.16.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1741,8 +1406,7 @@
|
||||
"te_text_model.encoder.layers.17.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.17.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.17.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.17.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1756,8 +1420,7 @@
|
||||
"te_text_model.encoder.layers.18.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.18.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.18.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.18.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1771,8 +1434,7 @@
|
||||
"te_text_model.encoder.layers.19.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.19.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.19.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.19.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1786,8 +1448,7 @@
|
||||
"te_text_model.encoder.layers.2.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.2.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.2.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.2.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1801,8 +1462,7 @@
|
||||
"te_text_model.encoder.layers.20.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.20.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.20.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.20.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1816,8 +1476,7 @@
|
||||
"te_text_model.encoder.layers.21.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.21.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.21.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.21.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1831,8 +1490,7 @@
|
||||
"te_text_model.encoder.layers.22.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.22.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.22.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.22.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1846,8 +1504,7 @@
|
||||
"te_text_model.encoder.layers.3.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.3.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.3.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.3.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1861,8 +1518,7 @@
|
||||
"te_text_model.encoder.layers.4.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.4.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.4.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.4.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1876,8 +1532,7 @@
|
||||
"te_text_model.encoder.layers.5.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.5.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.5.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.5.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1891,8 +1546,7 @@
|
||||
"te_text_model.encoder.layers.6.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.6.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.6.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.6.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1906,8 +1560,7 @@
|
||||
"te_text_model.encoder.layers.7.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.7.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.7.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.7.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1921,8 +1574,7 @@
|
||||
"te_text_model.encoder.layers.8.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.8.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.8.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.8.self_attn.MERGED.weight"
|
||||
]
|
||||
},
|
||||
"cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": {
|
||||
"cat": [
|
||||
@@ -1936,8 +1588,7 @@
|
||||
"te_text_model.encoder.layers.9.self_attn.q_proj.weight",
|
||||
"te_text_model.encoder.layers.9.self_attn.k_proj.weight",
|
||||
"te_text_model.encoder.layers.9.self_attn.v_proj.weight"
|
||||
],
|
||||
"target": "te_text_model.encoder.layers.9.self_attn.MERGED.weight"
|
||||
]
|
||||
}
|
||||
},
|
||||
"diffusers_ldm_operator_map": {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,16 +6,12 @@
|
||||
77
|
||||
],
|
||||
"min": 0.0,
|
||||
"max": 76.0,
|
||||
"mean": 38.0,
|
||||
"std": 22.375
|
||||
"max": 76.0
|
||||
},
|
||||
"conditioner.embedders.1.model.logit_scale": {
|
||||
"shape": [],
|
||||
"min": 4.60546875,
|
||||
"max": 4.60546875,
|
||||
"mean": 4.60546875,
|
||||
"std": NaN
|
||||
"max": 4.60546875
|
||||
},
|
||||
"conditioner.embedders.1.model.text_projection": {
|
||||
"shape": [
|
||||
@@ -23,9 +19,7 @@
|
||||
1280
|
||||
],
|
||||
"min": -0.15966796875,
|
||||
"max": 0.230712890625,
|
||||
"mean": 0.0,
|
||||
"std": 0.0181732177734375
|
||||
"max": 0.230712890625
|
||||
}
|
||||
},
|
||||
"diffusers": {
|
||||
@@ -35,9 +29,7 @@
|
||||
1280
|
||||
],
|
||||
"min": -0.15966796875,
|
||||
"max": 0.230712890625,
|
||||
"mean": 2.128152846125886e-05,
|
||||
"std": 0.018169498071074486
|
||||
"max": 0.230712890625
|
||||
}
|
||||
}
|
||||
}
|
||||
3419
toolkit/keymaps/stable_diffusion_ssd.json
Normal file
3419
toolkit/keymaps/stable_diffusion_ssd.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
Normal file
Binary file not shown.
21
toolkit/keymaps/stable_diffusion_ssd_unmatched.json
Normal file
21
toolkit/keymaps/stable_diffusion_ssd_unmatched.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"ldm": {
|
||||
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
|
||||
"shape": [
|
||||
1,
|
||||
77
|
||||
],
|
||||
"min": 0.0,
|
||||
"max": 76.0
|
||||
},
|
||||
"conditioner.embedders.1.model.text_model.embeddings.position_ids": {
|
||||
"shape": [
|
||||
1,
|
||||
77
|
||||
],
|
||||
"min": 0.0,
|
||||
"max": 76.0
|
||||
}
|
||||
},
|
||||
"diffusers": {}
|
||||
}
|
||||
@@ -5,14 +5,18 @@ import itertools
|
||||
|
||||
|
||||
class LosslessLatentDecoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentDecoder, self).__init__()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
if trainable:
|
||||
self.kernel = nn.Parameter(numpy_kernel)
|
||||
else:
|
||||
self.kernel = numpy_kernel
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
# my old code from tensorflow.
|
||||
@@ -35,19 +39,27 @@ class LosslessLatentDecoder(nn.Module):
|
||||
return kernel
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
if self.kernel.dtype != dtype:
|
||||
self.kernel = self.kernel.to(dtype=dtype)
|
||||
|
||||
# Deconvolve input tensor with the kernel
|
||||
return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
|
||||
|
||||
|
||||
class LosslessLatentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentEncoder, self).__init__()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||
if trainable:
|
||||
self.kernel = nn.Parameter(numpy_kernel)
|
||||
else:
|
||||
self.kernel = numpy_kernel
|
||||
|
||||
|
||||
def build_kernel(self, in_channels, latent_depth):
|
||||
@@ -70,18 +82,21 @@ class LosslessLatentEncoder(nn.Module):
|
||||
return kernel
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
if self.kernel.dtype != dtype:
|
||||
self.kernel = self.kernel.to(dtype=dtype)
|
||||
# Convolve input tensor with the kernel
|
||||
return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
|
||||
|
||||
|
||||
class LosslessLatentVAE(nn.Module):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||
def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
|
||||
super(LosslessLatentVAE, self).__init__()
|
||||
self.latent_depth = latent_depth
|
||||
self.in_channels = in_channels
|
||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
|
||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||
encoder_out_channels = self.encoder.out_channels
|
||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
|
||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
|
||||
|
||||
def forward(self, x):
|
||||
latent = self.latent_encoder(x)
|
||||
@@ -101,7 +116,7 @@ if __name__ == '__main__':
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
user_path = os.path.expanduser('~')
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
input_path = os.path.join(user_path, "Pictures/sample_2_512.png")
|
||||
|
||||
243
toolkit/lora.py
243
toolkit/lora.py
@@ -1,243 +0,0 @@
|
||||
# ref:
|
||||
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
||||
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
||||
# - https://github.com/p1atdev/LECO/blob/main/lora.py
|
||||
|
||||
import os
|
||||
import math
|
||||
from typing import Optional, List, Type, Set, Literal
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import UNet2DConditionModel
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
|
||||
"Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
|
||||
]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV = [
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
] # locon, 3clier
|
||||
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
|
||||
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
|
||||
|
||||
TRAINING_METHODS = Literal[
|
||||
"noxattn", # train all layers except x-attns and time_embed layers
|
||||
"innoxattn", # train all layers except self attention layers
|
||||
"selfattn", # ESD-u, train only self attention layers
|
||||
"xattn", # ESD-x, train only x attention layers
|
||||
"full", # train all layers
|
||||
# "notime",
|
||||
# "xlayer",
|
||||
# "outxattn",
|
||||
# "outsattn",
|
||||
# "inxattn",
|
||||
# "inmidsattn",
|
||||
# "selflayer",
|
||||
]
|
||||
|
||||
|
||||
class LoRAModule(nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == "Linear":
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
||||
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
||||
|
||||
elif org_module.__class__.__name__ == "Conv2d": # 一応
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
|
||||
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
|
||||
if self.lora_dim != lora_dim:
|
||||
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = nn.Conv2d(
|
||||
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
|
||||
)
|
||||
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().numpy()
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
# same as microsoft's
|
||||
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_up.weight)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
return (
|
||||
self.org_forward(x)
|
||||
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
)
|
||||
|
||||
|
||||
class LoRANetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
rank: int = 4,
|
||||
multiplier: float = 1.0,
|
||||
alpha: float = 1.0,
|
||||
train_method: TRAINING_METHODS = "full",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = rank
|
||||
self.alpha = alpha
|
||||
|
||||
# LoRAのみ
|
||||
self.module = LoRAModule
|
||||
|
||||
# unetのloraを作る
|
||||
self.unet_loras = self.create_modules(
|
||||
LORA_PREFIX_UNET,
|
||||
unet,
|
||||
DEFAULT_TARGET_REPLACE,
|
||||
self.lora_dim,
|
||||
self.multiplier,
|
||||
train_method=train_method,
|
||||
)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
# assertion 名前の被りがないか確認しているようだ
|
||||
lora_names = set()
|
||||
for lora in self.unet_loras:
|
||||
assert (
|
||||
lora.lora_name not in lora_names
|
||||
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
|
||||
lora_names.add(lora.lora_name)
|
||||
|
||||
# 適用する
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(
|
||||
lora.lora_name,
|
||||
lora,
|
||||
)
|
||||
|
||||
del unet
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_modules(
|
||||
self,
|
||||
prefix: str,
|
||||
root_module: nn.Module,
|
||||
target_replace_modules: List[str],
|
||||
rank: int,
|
||||
multiplier: float,
|
||||
train_method: TRAINING_METHODS,
|
||||
) -> list:
|
||||
loras = []
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if train_method == "noxattn": # Cross Attention と Time Embed 以外学習
|
||||
if "attn2" in name or "time_embed" in name:
|
||||
continue
|
||||
elif train_method == "innoxattn": # Cross Attention 以外学習
|
||||
if "attn2" in name:
|
||||
continue
|
||||
elif train_method == "selfattn": # Self Attention のみ学習
|
||||
if "attn1" not in name:
|
||||
continue
|
||||
elif train_method == "xattn": # Cross Attention のみ学習
|
||||
if "attn2" not in name:
|
||||
continue
|
||||
elif train_method == "full": # 全部学習
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"train_method: {train_method} is not implemented."
|
||||
)
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ in ["Linear", "Conv2d"]:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
print(f"{lora_name}")
|
||||
lora = self.module(
|
||||
lora_name, child_module, multiplier, rank, self.alpha
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
return loras
|
||||
|
||||
def prepare_optimizer_params(self):
|
||||
all_params = []
|
||||
|
||||
if self.unet_loras: # 実質これしかない
|
||||
params = []
|
||||
[params.extend(lora.parameters()) for lora in self.unet_loras]
|
||||
param_data = {"params": params}
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
|
||||
state_dict = self.state_dict()
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
if not key.startswith("lora"):
|
||||
# remove any not lora
|
||||
del state_dict[key]
|
||||
|
||||
metadata = add_model_hash_to_meta(state_dict, metadata)
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def __enter__(self):
|
||||
for lora in self.unet_loras:
|
||||
lora.multiplier = 1.0
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
for lora in self.unet_loras:
|
||||
lora.multiplier = 0
|
||||
@@ -1,14 +1,16 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from .config_modules import NetworkConfig
|
||||
from .lorm import count_parameters
|
||||
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
|
||||
from .paths import SD_SCRIPTS_ROOT
|
||||
from .train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
@@ -19,7 +21,18 @@ from torch.utils.checkpoint import checkpoint
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
# diffusers specific stuff
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
# 'GroupNorm',
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
@@ -34,12 +47,20 @@ class LoRAModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
network: 'LoRASpecialNetwork' = None,
|
||||
use_bias: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__()
|
||||
ToolkitModuleMixin.__init__(self, network=network)
|
||||
torch.nn.Module.__init__(self)
|
||||
self.lora_name = lora_name
|
||||
self.scalar = torch.tensor(1.0)
|
||||
# check if parent has bias. if not force use_bias to False
|
||||
if org_module.bias is None:
|
||||
use_bias = False
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
if org_module.__class__.__name__ in CONV_MODULES:
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
@@ -53,15 +74,15 @@ class LoRAModule(torch.nn.Module):
|
||||
# else:
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
if org_module.__class__.__name__ in CONV_MODULES:
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
|
||||
else:
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
@@ -74,157 +95,20 @@ class LoRAModule(torch.nn.Module):
|
||||
torch.nn.init.zeros_(self.lora_up.weight)
|
||||
|
||||
self.multiplier: Union[float, List[float]] = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
# wrap the original module so it doesn't get weights updated
|
||||
self.org_module = [org_module]
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
# this allows us to set different multipliers on a per item in a batch basis
|
||||
# allowing us to run positive and negative weights in the same batch
|
||||
# really only useful for slider training for now
|
||||
def get_multiplier(self, lora_up):
|
||||
with torch.no_grad():
|
||||
batch_size = lora_up.size(0)
|
||||
# batch will have all negative prompts first and positive prompts second
|
||||
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
|
||||
# if there is more than our multiplier, it is likely a batch size increase, so we need to
|
||||
# interleave the multipliers
|
||||
if isinstance(self.multiplier, list):
|
||||
if len(self.multiplier) == 0:
|
||||
# single item, just return it
|
||||
return self.multiplier[0]
|
||||
elif len(self.multiplier) == batch_size:
|
||||
# not doing CFG
|
||||
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
|
||||
else:
|
||||
|
||||
# we have a list of multipliers, so we need to get the multiplier for this batch
|
||||
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
|
||||
# should be 1 for if total batch size was 1
|
||||
num_interleaves = (batch_size // 2) // len(self.multiplier)
|
||||
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
|
||||
|
||||
# match lora_up rank
|
||||
if len(lora_up.size()) == 2:
|
||||
multiplier_tensor = multiplier_tensor.view(-1, 1)
|
||||
elif len(lora_up.size()) == 3:
|
||||
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
|
||||
elif len(lora_up.size()) == 4:
|
||||
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
|
||||
return multiplier_tensor.detach()
|
||||
|
||||
else:
|
||||
return self.multiplier
|
||||
|
||||
def _call_forward(self, x):
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return 0.0 # added to original forward
|
||||
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return lx * scale
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
lora_output = self._call_forward(x)
|
||||
|
||||
if self.is_normalizing:
|
||||
with torch.no_grad():
|
||||
# do this calculation without multiplier
|
||||
# get a dim array from orig forward that had index of all dimensions except the batch and channel
|
||||
|
||||
# Calculate the target magnitude for the combined output
|
||||
orig_max = torch.max(torch.abs(org_forwarded))
|
||||
|
||||
# Calculate the additional increase in magnitude that lora_output would introduce
|
||||
potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
|
||||
|
||||
epsilon = 1e-6 # Small constant to avoid division by zero
|
||||
|
||||
# Calculate the scaling factor for the lora_output
|
||||
# to ensure that the potential increase in magnitude doesn't change the original max
|
||||
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
|
||||
normalize_scaler = normalize_scaler.detach()
|
||||
|
||||
# save the scaler so it can be applied later
|
||||
self.normalize_scaler = normalize_scaler.clone().detach()
|
||||
|
||||
lora_output *= normalize_scaler
|
||||
|
||||
multiplier = self.get_multiplier(lora_output)
|
||||
|
||||
return org_forwarded + (lora_output * multiplier)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.is_checkpointing = True
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.is_checkpointing = False
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
|
||||
"""
|
||||
Applied the previous normalization calculation to the module.
|
||||
This must be called before saving or normalization will be lost.
|
||||
It is probably best to call after each batch as well.
|
||||
We just scale the up down weights to match this vector
|
||||
:return:
|
||||
"""
|
||||
# get state dict
|
||||
state_dict = self.state_dict()
|
||||
dtype = state_dict['lora_up.weight'].dtype
|
||||
device = state_dict['lora_up.weight'].device
|
||||
|
||||
# todo should we do this at fp32?
|
||||
|
||||
total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
|
||||
.to(device, dtype=dtype)
|
||||
num_modules_layers = 2 # up and down
|
||||
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
|
||||
.to(device, dtype=dtype)
|
||||
|
||||
# apply the scaler to the up and down weights
|
||||
for key in state_dict.keys():
|
||||
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
|
||||
# do it inplace do params are updated
|
||||
state_dict[key] *= up_down_scale
|
||||
|
||||
# reset the normalization scaler
|
||||
self.normalize_scaler = target_normalize_scaler
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
# del self.org_module
|
||||
|
||||
|
||||
class LoRASpecialNetwork(LoRANetwork):
|
||||
class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
@@ -258,7 +142,18 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
module_class: Type[object] = LoRAModule,
|
||||
varbose: Optional[bool] = False,
|
||||
train_text_encoder: Optional[bool] = True,
|
||||
use_text_encoder_1: bool = True,
|
||||
use_text_encoder_2: bool = True,
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
use_bias: bool = False,
|
||||
is_lorm: bool = False,
|
||||
ignore_if_contains = None,
|
||||
parameter_threshold: float = 0.0,
|
||||
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
|
||||
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||
@@ -269,8 +164,19 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
5. modules_dimとmodules_alphaを指定 (推論用)
|
||||
"""
|
||||
# call the parent of the parent we are replacing (LoRANetwork) init
|
||||
super(LoRANetwork, self).__init__()
|
||||
|
||||
torch.nn.Module.__init__(self)
|
||||
ToolkitNetworkMixin.__init__(
|
||||
self,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
is_sdxl=is_sdxl,
|
||||
is_v2=is_v2,
|
||||
is_lorm=is_lorm,
|
||||
**kwargs
|
||||
)
|
||||
if ignore_if_contains is None:
|
||||
ignore_if_contains = []
|
||||
self.ignore_if_contains = ignore_if_contains
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.conv_lora_dim = conv_lora_dim
|
||||
@@ -281,9 +187,11 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
self.is_checkpointing = False
|
||||
self._multiplier: float = 1.0
|
||||
self.is_active: bool = False
|
||||
self._is_normalizing: bool = False
|
||||
self.torch_multiplier = None
|
||||
# triggers the state updates
|
||||
self.multiplier = multiplier
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
@@ -325,11 +233,19 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_linear = child_module.__class__.__name__ in LINEAR_MODULES
|
||||
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
skip = False
|
||||
if any([word in child_name for word in self.ignore_if_contains]):
|
||||
skip = True
|
||||
|
||||
# see if it is over threshold
|
||||
if count_parameters(child_module) < parameter_threshold:
|
||||
skip = True
|
||||
|
||||
if (is_linear or is_conv2d) and not skip:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
@@ -375,6 +291,9 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
)
|
||||
loras.append(lora)
|
||||
return loras, skipped
|
||||
@@ -387,6 +306,10 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
skipped_te = []
|
||||
if train_text_encoder:
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if not use_text_encoder_1 and i == 0:
|
||||
continue
|
||||
if not use_text_encoder_2 and i == 1:
|
||||
continue
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
print(f"create LoRA for Text Encoder {index}:")
|
||||
@@ -401,9 +324,9 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
target_modules = target_lin_modules
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
target_modules += target_conv_modules
|
||||
|
||||
if train_unet:
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
@@ -430,106 +353,3 @@ class LoRASpecialNetwork(LoRANetwork):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
@property
|
||||
def multiplier(self) -> Union[float, List[float]]:
|
||||
return self._multiplier
|
||||
|
||||
@multiplier.setter
|
||||
def multiplier(self, value: Union[float, List[float]]):
|
||||
self._multiplier = value
|
||||
self._update_lora_multiplier()
|
||||
|
||||
def _update_lora_multiplier(self):
|
||||
|
||||
if self.is_active:
|
||||
if hasattr(self, 'unet_loras'):
|
||||
for lora in self.unet_loras:
|
||||
lora.multiplier = self._multiplier
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = self._multiplier
|
||||
else:
|
||||
if hasattr(self, 'unet_loras'):
|
||||
for lora in self.unet_loras:
|
||||
lora.multiplier = 0
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = 0
|
||||
|
||||
# called when the context manager is entered
|
||||
# ie: with network:
|
||||
def __enter__(self):
|
||||
self.is_active = True
|
||||
self._update_lora_multiplier()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
self.is_active = False
|
||||
self._update_lora_multiplier()
|
||||
|
||||
def force_to(self, device, dtype):
|
||||
self.to(device, dtype)
|
||||
loras = []
|
||||
if hasattr(self, 'unet_loras'):
|
||||
loras += self.unet_loras
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
loras += self.text_encoder_loras
|
||||
for lora in loras:
|
||||
lora.to(device, dtype)
|
||||
|
||||
def get_all_modules(self):
|
||||
loras = []
|
||||
if hasattr(self, 'unet_loras'):
|
||||
loras += self.unet_loras
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
loras += self.text_encoder_loras
|
||||
return loras
|
||||
|
||||
def _update_checkpointing(self):
|
||||
for module in self.get_all_modules():
|
||||
if self.is_checkpointing:
|
||||
module.enable_gradient_checkpointing()
|
||||
else:
|
||||
module.disable_gradient_checkpointing()
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
self.is_checkpointing = True
|
||||
self._update_checkpointing()
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
self.is_checkpointing = False
|
||||
self._update_checkpointing()
|
||||
|
||||
@property
|
||||
def is_normalizing(self) -> bool:
|
||||
return self._is_normalizing
|
||||
|
||||
@is_normalizing.setter
|
||||
def is_normalizing(self, value: bool):
|
||||
self._is_normalizing = value
|
||||
for module in self.get_all_modules():
|
||||
module.is_normalizing = self._is_normalizing
|
||||
|
||||
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
|
||||
for module in self.get_all_modules():
|
||||
module.apply_stored_normalizer(target_normalize_scaler)
|
||||
|
||||
460
toolkit/lorm.py
Normal file
460
toolkit/lorm.py
Normal file
@@ -0,0 +1,460 @@
|
||||
from typing import Union, Tuple, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import UNet2DConditionModel
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import LoRMConfig
|
||||
|
||||
conv = nn.Conv2d
|
||||
lin = nn.Linear
|
||||
_size_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
ExtractMode = Union[
|
||||
'fixed',
|
||||
'threshold',
|
||||
'ratio',
|
||||
'quantile',
|
||||
'percentage'
|
||||
]
|
||||
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
]
|
||||
CONV_MODULES = [
|
||||
# 'Conv2d',
|
||||
# 'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
# "ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
]
|
||||
|
||||
LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
|
||||
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
"time_embedding.linear_1",
|
||||
"time_embedding.linear_2",
|
||||
]
|
||||
|
||||
UNET_MODULES_TO_AVOID = [
|
||||
]
|
||||
|
||||
|
||||
# Low Rank Convolution
|
||||
class LoRMCon2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
lorm_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_2_t,
|
||||
stride: _size_2_t = 1,
|
||||
padding: Union[str, _size_2_t] = 'same',
|
||||
dilation: _size_2_t = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.lorm_channels = lorm_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
self.down = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=lorm_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
padding_mode=padding_mode,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
# Kernel size on the up is always 1x1.
|
||||
# I don't think you could calculate a dual 3x3, or I can't at least
|
||||
|
||||
self.up = nn.Conv2d(
|
||||
in_channels=lorm_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(1, 1),
|
||||
stride=1,
|
||||
padding='same',
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
padding_mode='zeros',
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
|
||||
x = input
|
||||
x = self.down(x)
|
||||
x = self.up(x)
|
||||
return x
|
||||
|
||||
|
||||
class LoRMLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
lorm_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.lorm_features = lorm_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.down = nn.Linear(
|
||||
in_features=in_features,
|
||||
out_features=lorm_features,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
|
||||
)
|
||||
self.up = nn.Linear(
|
||||
in_features=lorm_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
# bias=True,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
|
||||
x = input
|
||||
x = self.down(x)
|
||||
x = self.up(x)
|
||||
return x
|
||||
|
||||
|
||||
def extract_conv(
|
||||
weight: Union[torch.Tensor, nn.Parameter],
|
||||
mode='fixed',
|
||||
mode_param=0,
|
||||
device='cpu'
|
||||
) -> Tuple[Tensor, Tensor, int, Tensor]:
|
||||
weight = weight.to(device)
|
||||
out_ch, in_ch, kernel_size, _ = weight.shape
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
|
||||
if mode == 'percentage':
|
||||
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
|
||||
original_params = out_ch * in_ch * kernel_size * kernel_size
|
||||
desired_params = mode_param * original_params
|
||||
# Solve for lora_rank from the equation
|
||||
lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
|
||||
elif mode == 'fixed':
|
||||
lora_rank = mode_param
|
||||
elif mode == 'threshold':
|
||||
assert mode_param >= 0
|
||||
lora_rank = torch.sum(S > mode_param).item()
|
||||
elif mode == 'ratio':
|
||||
assert 1 >= mode_param >= 0
|
||||
min_s = torch.max(S) * mode_param
|
||||
lora_rank = torch.sum(S > min_s).item()
|
||||
elif mode == 'quantile' or mode == 'percentile':
|
||||
assert 1 >= mode_param >= 0
|
||||
s_cum = torch.cumsum(S, dim=0)
|
||||
min_cum_sum = mode_param * torch.sum(S)
|
||||
lora_rank = torch.sum(s_cum < min_cum_sum).item()
|
||||
else:
|
||||
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
||||
lora_rank = max(1, lora_rank)
|
||||
lora_rank = min(out_ch, in_ch, lora_rank)
|
||||
if lora_rank >= out_ch / 2:
|
||||
lora_rank = int(out_ch / 2)
|
||||
print(f"rank is higher than it should be")
|
||||
# print(f"Skipping layer as determined rank is too high")
|
||||
# return None, None, None, None
|
||||
# return weight, 'full'
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
|
||||
extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
|
||||
extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
|
||||
del U, S, Vh, weight
|
||||
return extract_weight_A, extract_weight_B, lora_rank, diff
|
||||
|
||||
|
||||
def extract_linear(
|
||||
weight: Union[torch.Tensor, nn.Parameter],
|
||||
mode='fixed',
|
||||
mode_param=0,
|
||||
device='cpu',
|
||||
) -> Tuple[Tensor, Tensor, int, Tensor]:
|
||||
weight = weight.to(device)
|
||||
out_ch, in_ch = weight.shape
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight)
|
||||
|
||||
if mode == 'percentage':
|
||||
assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
|
||||
desired_params = mode_param * out_ch * in_ch
|
||||
# Solve for lora_rank from the equation
|
||||
lora_rank = int(desired_params / (in_ch + out_ch))
|
||||
elif mode == 'fixed':
|
||||
lora_rank = mode_param
|
||||
elif mode == 'threshold':
|
||||
assert mode_param >= 0
|
||||
lora_rank = torch.sum(S > mode_param).item()
|
||||
elif mode == 'ratio':
|
||||
assert 1 >= mode_param >= 0
|
||||
min_s = torch.max(S) * mode_param
|
||||
lora_rank = torch.sum(S > min_s).item()
|
||||
elif mode == 'quantile':
|
||||
assert 1 >= mode_param >= 0
|
||||
s_cum = torch.cumsum(S, dim=0)
|
||||
min_cum_sum = mode_param * torch.sum(S)
|
||||
lora_rank = torch.sum(s_cum < min_cum_sum).item()
|
||||
else:
|
||||
raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
|
||||
lora_rank = max(1, lora_rank)
|
||||
lora_rank = min(out_ch, in_ch, lora_rank)
|
||||
if lora_rank >= out_ch / 2:
|
||||
# print(f"rank is higher than it should be")
|
||||
lora_rank = int(out_ch / 2)
|
||||
# return weight, 'full'
|
||||
# print(f"Skipping layer as determined rank is too high")
|
||||
# return None, None, None, None
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
diff = (weight - U @ Vh).detach()
|
||||
extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
|
||||
extract_weight_B = U.reshape(out_ch, lora_rank).detach()
|
||||
del U, S, Vh, weight
|
||||
return extract_weight_A, extract_weight_B, lora_rank, diff
|
||||
|
||||
|
||||
def replace_module_by_path(network, name, module):
|
||||
"""Replace a module in a network by its name."""
|
||||
name_parts = name.split('.')
|
||||
current_module = network
|
||||
for part in name_parts[:-1]:
|
||||
current_module = getattr(current_module, part)
|
||||
try:
|
||||
setattr(current_module, name_parts[-1], module)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
def count_parameters(module):
|
||||
return sum(p.numel() for p in module.parameters())
|
||||
|
||||
|
||||
def compute_optimal_bias(original_module, linear_down, linear_up, X):
|
||||
Y_original = original_module(X)
|
||||
Y_approx = linear_up(linear_down(X))
|
||||
E = Y_original - Y_approx
|
||||
|
||||
optimal_bias = E.mean(dim=0)
|
||||
|
||||
return optimal_bias
|
||||
|
||||
|
||||
def format_with_commas(n):
|
||||
return f"{n:,}"
|
||||
|
||||
|
||||
def print_lorm_extract_details(
|
||||
start_num_params: int,
|
||||
end_num_params: int,
|
||||
num_replaced: int,
|
||||
):
|
||||
start_formatted = format_with_commas(start_num_params)
|
||||
end_formatted = format_with_commas(end_num_params)
|
||||
num_replaced_formatted = format_with_commas(num_replaced)
|
||||
|
||||
width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
|
||||
|
||||
print(f"Convert UNet result:")
|
||||
print(f" - converted: {num_replaced:>{width},} modules")
|
||||
print(f" - start: {start_num_params:>{width},} params")
|
||||
print(f" - end: {end_num_params:>{width},} params")
|
||||
|
||||
|
||||
lorm_ignore_if_contains = [
|
||||
'proj_out', 'proj_in',
|
||||
]
|
||||
|
||||
lorm_parameter_threshold = 1000000
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffusers_unet_to_lorm(
|
||||
unet: UNet2DConditionModel,
|
||||
config: LoRMConfig,
|
||||
):
|
||||
print('Converting UNet to LoRM UNet')
|
||||
start_num_params = count_parameters(unet)
|
||||
named_modules = list(unet.named_modules())
|
||||
|
||||
num_replaced = 0
|
||||
|
||||
pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
|
||||
layer_names_replaced = []
|
||||
converted_modules = []
|
||||
ignore_if_contains = [
|
||||
'proj_out', 'proj_in',
|
||||
]
|
||||
|
||||
for name, module in named_modules:
|
||||
module_name = module.__class__.__name__
|
||||
if module_name in UNET_TARGET_REPLACE_MODULE:
|
||||
for child_name, child_module in module.named_modules():
|
||||
new_module: Union[LoRMCon2d, LoRMLinear, None] = None
|
||||
# if child name includes attn, skip it
|
||||
combined_name = combined_name = f"{name}.{child_name}"
|
||||
# if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
|
||||
# pass
|
||||
|
||||
lorm_config = config.get_config_for_module(combined_name)
|
||||
|
||||
extract_mode = lorm_config.extract_mode
|
||||
extract_mode_param = lorm_config.extract_mode_param
|
||||
parameter_threshold = lorm_config.parameter_threshold
|
||||
|
||||
if any([word in child_name for word in ignore_if_contains]):
|
||||
pass
|
||||
|
||||
elif child_module.__class__.__name__ in LINEAR_MODULES:
|
||||
if count_parameters(child_module) > parameter_threshold:
|
||||
|
||||
dtype = child_module.weight.dtype
|
||||
# extract and convert
|
||||
down_weight, up_weight, lora_dim, diff = extract_linear(
|
||||
weight=child_module.weight.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=extract_mode_param,
|
||||
device=child_module.weight.device,
|
||||
)
|
||||
if down_weight is None:
|
||||
continue
|
||||
down_weight = down_weight.to(dtype=dtype)
|
||||
up_weight = up_weight.to(dtype=dtype)
|
||||
bias_weight = None
|
||||
if child_module.bias is not None:
|
||||
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
|
||||
# linear layer weights = (out_features, in_features)
|
||||
new_module = LoRMLinear(
|
||||
in_features=down_weight.shape[1],
|
||||
lorm_features=lora_dim,
|
||||
out_features=up_weight.shape[0],
|
||||
bias=bias_weight is not None,
|
||||
device=down_weight.device,
|
||||
dtype=down_weight.dtype
|
||||
)
|
||||
|
||||
# replace the weights
|
||||
new_module.down.weight.data = down_weight
|
||||
new_module.up.weight.data = up_weight
|
||||
if bias_weight is not None:
|
||||
new_module.up.bias.data = bias_weight
|
||||
# else:
|
||||
# new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)
|
||||
|
||||
# bias_correction = compute_optimal_bias(
|
||||
# child_module,
|
||||
# new_module.down,
|
||||
# new_module.up,
|
||||
# torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
|
||||
# )
|
||||
# new_module.up.bias.data += bias_correction
|
||||
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
if count_parameters(child_module) > parameter_threshold:
|
||||
dtype = child_module.weight.dtype
|
||||
down_weight, up_weight, lora_dim, diff = extract_conv(
|
||||
weight=child_module.weight.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=extract_mode_param,
|
||||
device=child_module.weight.device,
|
||||
)
|
||||
if down_weight is None:
|
||||
continue
|
||||
down_weight = down_weight.to(dtype=dtype)
|
||||
up_weight = up_weight.to(dtype=dtype)
|
||||
bias_weight = None
|
||||
if child_module.bias is not None:
|
||||
bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
|
||||
|
||||
new_module = LoRMCon2d(
|
||||
in_channels=down_weight.shape[1],
|
||||
lorm_channels=lora_dim,
|
||||
out_channels=up_weight.shape[0],
|
||||
kernel_size=child_module.kernel_size,
|
||||
dilation=child_module.dilation,
|
||||
padding=child_module.padding,
|
||||
padding_mode=child_module.padding_mode,
|
||||
stride=child_module.stride,
|
||||
bias=bias_weight is not None,
|
||||
device=down_weight.device,
|
||||
dtype=down_weight.dtype
|
||||
)
|
||||
# replace the weights
|
||||
new_module.down.weight.data = down_weight
|
||||
new_module.up.weight.data = up_weight
|
||||
if bias_weight is not None:
|
||||
new_module.up.bias.data = bias_weight
|
||||
|
||||
if new_module:
|
||||
combined_name = f"{name}.{child_name}"
|
||||
replace_module_by_path(unet, combined_name, new_module)
|
||||
converted_modules.append(new_module)
|
||||
num_replaced += 1
|
||||
layer_names_replaced.append(
|
||||
f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
|
||||
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
end_num_params = count_parameters(unet)
|
||||
|
||||
def sorting_key(s):
|
||||
# Extract the number part, remove commas, and convert to integer
|
||||
return int(s.split("-")[1].strip().replace(",", ""))
|
||||
|
||||
sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
|
||||
for layer_name in sorted_layer_names_replaced:
|
||||
print(layer_name)
|
||||
|
||||
print_lorm_extract_details(
|
||||
start_num_params=start_num_params,
|
||||
end_num_params=end_num_params,
|
||||
num_replaced=num_replaced,
|
||||
)
|
||||
|
||||
return converted_modules
|
||||
@@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module):
|
||||
# Gradient penalty
|
||||
def get_gradient_penalty(critic, real, fake, device):
|
||||
with torch.autocast(device_type='cuda'):
|
||||
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
|
||||
real = real.float()
|
||||
fake = fake.float()
|
||||
alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
|
||||
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
|
||||
if torch.isnan(interpolates).any():
|
||||
print('d_interpolates is nan')
|
||||
d_interpolates = critic(interpolates)
|
||||
fake = torch.ones(real.size(0), 1, device=device)
|
||||
|
||||
|
||||
if torch.isnan(d_interpolates).any():
|
||||
print('fake is nan')
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=d_interpolates,
|
||||
inputs=interpolates,
|
||||
@@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device):
|
||||
only_inputs=True,
|
||||
)[0]
|
||||
|
||||
# see if any are nan
|
||||
if torch.isnan(gradients).any():
|
||||
print('gradients is nan')
|
||||
|
||||
gradients = gradients.view(gradients.size(0), -1)
|
||||
gradient_norm = gradients.norm(2, dim=1)
|
||||
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
|
||||
return gradient_penalty
|
||||
return gradient_penalty.float()
|
||||
|
||||
|
||||
class PatternLoss(torch.nn.Module):
|
||||
|
||||
373
toolkit/lycoris_special.py
Normal file
373
toolkit/lycoris_special.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Union, List, Type
|
||||
|
||||
import torch
|
||||
from lycoris.kohya import LycorisNetwork, LoConModule
|
||||
from lycoris.modules.glora import GLoRAModule
|
||||
from torch import nn
|
||||
from transformers import CLIPTextModel
|
||||
from torch.nn import functional as F
|
||||
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
|
||||
|
||||
# diffusers specific stuff
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name, org_module: nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4, alpha=1,
|
||||
dropout=0., rank_dropout=0., module_dropout=0.,
|
||||
use_cp=False,
|
||||
network: 'LycorisSpecialNetwork' = None,
|
||||
use_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
# call super of super
|
||||
ToolkitModuleMixin.__init__(self, network=network)
|
||||
torch.nn.Module.__init__(self)
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.cp = False
|
||||
|
||||
# check if parent has bias. if not force use_bias to False
|
||||
if org_module.bias is None:
|
||||
use_bias = False
|
||||
|
||||
self.scalar = nn.Parameter(torch.tensor(0.0))
|
||||
orig_module_name = org_module.__class__.__name__
|
||||
if orig_module_name in CONV_MODULES:
|
||||
self.isconv = True
|
||||
# For general LoCon
|
||||
in_dim = org_module.in_channels
|
||||
k_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
out_dim = org_module.out_channels
|
||||
self.down_op = F.conv2d
|
||||
self.up_op = F.conv2d
|
||||
if use_cp and k_size != (1, 1):
|
||||
self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
||||
self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
|
||||
self.cp = True
|
||||
else:
|
||||
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
||||
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
|
||||
elif orig_module_name in LINEAR_MODULES:
|
||||
self.isconv = False
|
||||
self.down_op = F.linear
|
||||
self.up_op = F.linear
|
||||
if orig_module_name == 'GroupNorm':
|
||||
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
|
||||
in_dim = org_module.num_channels
|
||||
out_dim = org_module.num_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
||||
self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.shape = org_module.weight.shape
|
||||
|
||||
if dropout:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.dropout = nn.Identity()
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lora_up.weight)
|
||||
if self.cp:
|
||||
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module]
|
||||
self.register_load_state_dict_post_hook(self.load_weight_hook)
|
||||
|
||||
def load_weight_hook(self, *args, **kwargs):
|
||||
self.scalar = nn.Parameter(torch.ones_like(self.scalar))
|
||||
|
||||
|
||||
class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
# 'UNet2DConditionModel',
|
||||
# 'Conv2d',
|
||||
# 'Timesteps',
|
||||
# 'TimestepEmbedding',
|
||||
# 'Linear',
|
||||
# 'SiLU',
|
||||
# 'ModuleList',
|
||||
# 'DownBlock2D',
|
||||
# 'ResnetBlock2D', # need
|
||||
# 'GroupNorm',
|
||||
# 'LoRACompatibleConv',
|
||||
# 'LoRACompatibleLinear',
|
||||
# 'Dropout',
|
||||
# 'CrossAttnDownBlock2D', # needed
|
||||
# 'Transformer2DModel', # maybe not, has duplicates
|
||||
# 'BasicTransformerBlock', # duplicates
|
||||
# 'LayerNorm',
|
||||
# 'Attention',
|
||||
# 'FeedForward',
|
||||
# 'GEGLU',
|
||||
# 'UpBlock2D',
|
||||
# 'UNetMidBlock2DCrossAttn'
|
||||
]
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
"time_embedding.linear_1",
|
||||
"time_embedding.linear_2",
|
||||
]
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
unet,
|
||||
multiplier: float = 1.0,
|
||||
lora_dim: int = 4,
|
||||
alpha: float = 1,
|
||||
dropout: Optional[float] = None,
|
||||
rank_dropout: Optional[float] = None,
|
||||
module_dropout: Optional[float] = None,
|
||||
conv_lora_dim: Optional[int] = None,
|
||||
conv_alpha: Optional[float] = None,
|
||||
use_cp: Optional[bool] = False,
|
||||
network_module: Type[object] = LoConSpecialModule,
|
||||
train_unet: bool = True,
|
||||
train_text_encoder: bool = True,
|
||||
use_text_encoder_1: bool = True,
|
||||
use_text_encoder_2: bool = True,
|
||||
use_bias: bool = False,
|
||||
is_lorm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# call ToolkitNetworkMixin super
|
||||
ToolkitNetworkMixin.__init__(
|
||||
self,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
is_lorm=is_lorm,
|
||||
**kwargs
|
||||
)
|
||||
# call the parent of the parent LycorisNetwork
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
# LyCORIS unique stuff
|
||||
if dropout is None:
|
||||
dropout = 0
|
||||
if rank_dropout is None:
|
||||
rank_dropout = 0
|
||||
if module_dropout is None:
|
||||
module_dropout = 0
|
||||
self.train_unet = train_unet
|
||||
self.train_text_encoder = train_text_encoder
|
||||
|
||||
self.torch_multiplier = None
|
||||
# triggers a tensor update
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if not self.ENABLE_CONV or conv_lora_dim is None:
|
||||
conv_lora_dim = 0
|
||||
conv_alpha = 0
|
||||
|
||||
self.conv_lora_dim = int(conv_lora_dim)
|
||||
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
|
||||
print('Apply different lora dim for conv layer')
|
||||
print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
|
||||
elif self.conv_lora_dim == 0:
|
||||
print('Disable conv layer')
|
||||
|
||||
self.alpha = alpha
|
||||
self.conv_alpha = float(conv_alpha)
|
||||
if self.conv_lora_dim and self.alpha != self.conv_alpha:
|
||||
print('Apply different alpha value for conv layer')
|
||||
print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
|
||||
|
||||
if 1 >= dropout >= 0:
|
||||
print(f'Use Dropout value: {dropout}')
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
prefix,
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules,
|
||||
target_replace_names=[]
|
||||
) -> List[network_module]:
|
||||
print('Create LyCORIS Module')
|
||||
loras = []
|
||||
# remove this
|
||||
named_modules = root_module.named_modules()
|
||||
# add a few to tthe generator
|
||||
|
||||
for name, module in named_modules:
|
||||
module_name = module.__class__.__name__
|
||||
if module_name in target_replace_modules:
|
||||
if module_name in self.MODULE_ALGO_MAP:
|
||||
algo = self.MODULE_ALGO_MAP[module_name]
|
||||
else:
|
||||
algo = network_module
|
||||
for child_name, child_module in module.named_modules():
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'):
|
||||
print(f"{lora_name}")
|
||||
|
||||
if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
k_size, *_ = child_module.kernel_size
|
||||
if k_size == 1 and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
loras.append(lora)
|
||||
elif name in target_replace_names:
|
||||
if name in self.NAME_ALGO_MAP:
|
||||
algo = self.NAME_ALGO_MAP[name]
|
||||
else:
|
||||
algo = network_module
|
||||
lora_name = prefix + '.' + name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
if module.__class__.__name__ == 'Linear' and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
parent=module,
|
||||
network=self,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
elif module.__class__.__name__ == 'Conv2d':
|
||||
k_size, *_ = module.kernel_size
|
||||
if k_size == 1 and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
if network_module == GLoRAModule:
|
||||
print('GLoRA enabled, only train transformer')
|
||||
# only train transformer (for GLoRA)
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"Attention",
|
||||
]
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []
|
||||
|
||||
if isinstance(text_encoder, list):
|
||||
text_encoders = text_encoder
|
||||
use_index = True
|
||||
else:
|
||||
text_encoders = [text_encoder]
|
||||
use_index = False
|
||||
|
||||
self.text_encoder_loras = []
|
||||
if self.train_text_encoder:
|
||||
for i, te in enumerate(text_encoders):
|
||||
if not use_text_encoder_1 and i == 0:
|
||||
continue
|
||||
if not use_text_encoder_2 and i == 1:
|
||||
continue
|
||||
self.text_encoder_loras.extend(create_modules(
|
||||
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
|
||||
te,
|
||||
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
))
|
||||
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if self.train_unet:
|
||||
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
|
||||
else:
|
||||
self.unet_loras = []
|
||||
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
self.weights_sd = None
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
|
||||
|
||||
|
||||
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
return parse_metadata_from_safetensors(metadata)
|
||||
try:
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
return parse_metadata_from_safetensors(metadata)
|
||||
except Exception as e:
|
||||
print(f"Error loading metadata from {file_path}: {e}")
|
||||
return OrderedDict()
|
||||
|
||||
566
toolkit/network_mixins.py
Normal file
566
toolkit/network_mixins.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import weakref
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import NetworkConfig
|
||||
from toolkit.lorm import extract_conv, extract_linear, count_parameters
|
||||
from toolkit.metadata import add_model_hash_to_meta
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
||||
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
|
||||
Module = Union['LoConSpecialModule', 'LoRAModule']
|
||||
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
# 'GroupNorm',
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
ExtractMode = Union[
|
||||
'existing'
|
||||
'fixed',
|
||||
'threshold',
|
||||
'ratio',
|
||||
'quantile',
|
||||
'percentage'
|
||||
]
|
||||
|
||||
|
||||
def broadcast_and_multiply(tensor, multiplier):
|
||||
# Determine the number of dimensions required
|
||||
num_extra_dims = tensor.dim() - multiplier.dim()
|
||||
|
||||
# Unsqueezing the tensor to match the dimensionality
|
||||
for _ in range(num_extra_dims):
|
||||
multiplier = multiplier.unsqueeze(-1)
|
||||
|
||||
# Multiplying the broadcasted tensor with the output tensor
|
||||
result = tensor * multiplier
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_bias(tensor, bias):
|
||||
if bias is None:
|
||||
return tensor
|
||||
# add batch dim
|
||||
bias = bias.unsqueeze(0)
|
||||
bias = torch.cat([bias] * tensor.size(0), dim=0)
|
||||
# Determine the number of dimensions required
|
||||
num_extra_dims = tensor.dim() - bias.dim()
|
||||
|
||||
# Unsqueezing the tensor to match the dimensionality
|
||||
for _ in range(num_extra_dims):
|
||||
bias = bias.unsqueeze(-1)
|
||||
|
||||
# we may need to swap -1 for -2
|
||||
if bias.size(1) != tensor.size(1):
|
||||
if len(bias.size()) == 3:
|
||||
bias = bias.permute(0, 2, 1)
|
||||
elif len(bias.size()) == 4:
|
||||
bias = bias.permute(0, 3, 1, 2)
|
||||
|
||||
# Multiplying the broadcasted tensor with the output tensor
|
||||
try:
|
||||
result = tensor + bias
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
print(tensor.size())
|
||||
print(bias.size())
|
||||
raise e
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ExtractableModuleMixin:
|
||||
def extract_weight(
|
||||
self: Module,
|
||||
extract_mode: ExtractMode = "existing",
|
||||
extract_mode_param: Union[int, float] = None,
|
||||
):
|
||||
device = self.lora_down.weight.device
|
||||
weight_to_extract = self.org_module[0].weight
|
||||
if extract_mode == "existing":
|
||||
extract_mode = 'fixed'
|
||||
extract_mode_param = self.lora_dim
|
||||
|
||||
if self.org_module[0].__class__.__name__ in CONV_MODULES:
|
||||
# do conv extraction
|
||||
down_weight, up_weight, new_dim, diff = extract_conv(
|
||||
weight=weight_to_extract.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=extract_mode_param,
|
||||
device=device
|
||||
)
|
||||
|
||||
elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
|
||||
# do linear extraction
|
||||
down_weight, up_weight, new_dim, diff = extract_linear(
|
||||
weight=weight_to_extract.clone().detach().float(),
|
||||
mode=extract_mode,
|
||||
mode_param=extract_mode_param,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")
|
||||
|
||||
self.lora_dim = new_dim
|
||||
|
||||
# inject weights into the param
|
||||
self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
|
||||
self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()
|
||||
|
||||
# copy bias if we have one and are using them
|
||||
if self.org_module[0].bias is not None and self.lora_up.bias is not None:
|
||||
self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()
|
||||
|
||||
# set up alphas
|
||||
self.alpha = (self.alpha * 0) + down_weight.shape[0]
|
||||
self.scale = self.alpha / self.lora_dim
|
||||
|
||||
# assign them
|
||||
|
||||
# handle trainable scaler method locon does
|
||||
if hasattr(self, 'scalar'):
|
||||
# scaler is a parameter update the value with 1.0
|
||||
self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)
|
||||
|
||||
|
||||
class ToolkitModuleMixin:
|
||||
def __init__(
|
||||
self: Module,
|
||||
*args,
|
||||
network: Network,
|
||||
**kwargs
|
||||
):
|
||||
self.network_ref: weakref.ref = weakref.ref(network)
|
||||
self.is_checkpointing = False
|
||||
self._multiplier: Union[float, list, torch.Tensor] = None
|
||||
|
||||
def _call_forward(self: Module, x):
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return 0.0 # added to original forward
|
||||
|
||||
if hasattr(self, 'lora_mid') and self.lora_mid is not None:
|
||||
lx = self.lora_mid(self.lora_down(x))
|
||||
else:
|
||||
try:
|
||||
lx = self.lora_down(x)
|
||||
except RuntimeError as e:
|
||||
print(f"Error in {self.__class__.__name__} lora_down")
|
||||
|
||||
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
|
||||
lx = self.dropout(lx)
|
||||
# normal dropout
|
||||
elif self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.rank_dropout > 0 and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# handle trainable scaler method locon does
|
||||
if hasattr(self, 'scalar'):
|
||||
scale = scale * self.scalar
|
||||
|
||||
return lx * scale
|
||||
|
||||
|
||||
def lorm_forward(self: Network, x, *args, **kwargs):
|
||||
network: Network = self.network_ref()
|
||||
if not network.is_active:
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
if network.lorm_train_mode == 'local':
|
||||
# we are going to predict input with both and do a loss on them
|
||||
inputs = x.detach()
|
||||
with torch.no_grad():
|
||||
# get the local prediction
|
||||
target_pred = self.org_forward(inputs, *args, **kwargs).detach()
|
||||
with torch.set_grad_enabled(True):
|
||||
# make a prediction with the lorm
|
||||
lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))
|
||||
|
||||
local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
|
||||
# backpropr
|
||||
local_loss.backward()
|
||||
|
||||
network.module_losses.append(local_loss.detach())
|
||||
# return the original as we dont want our trainer to affect ones down the line
|
||||
return target_pred
|
||||
|
||||
else:
|
||||
return self.lora_up(self.lora_down(x))
|
||||
|
||||
def forward(self: Module, x, *args, **kwargs):
|
||||
skip = False
|
||||
network: Network = self.network_ref()
|
||||
if network.is_lorm:
|
||||
# we are doing lorm
|
||||
return self.lorm_forward(x, *args, **kwargs)
|
||||
|
||||
# skip if not active
|
||||
if not network.is_active:
|
||||
skip = True
|
||||
|
||||
# skip if is merged in
|
||||
if network.is_merged_in:
|
||||
skip = True
|
||||
|
||||
# skip if multiplier is 0
|
||||
if network._multiplier == 0:
|
||||
skip = True
|
||||
|
||||
if skip:
|
||||
# network is not active, avoid doing anything
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
|
||||
lora_output_batch_size = lora_output.size(0)
|
||||
multiplier_batch_size = multiplier.size(0)
|
||||
if lora_output_batch_size != multiplier_batch_size:
|
||||
num_interleaves = lora_output_batch_size // multiplier_batch_size
|
||||
# todo check if this is correct, do we just concat when doing cfg?
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
|
||||
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
|
||||
return x
|
||||
|
||||
def enable_gradient_checkpointing(self: Module):
|
||||
self.is_checkpointing = True
|
||||
|
||||
def disable_gradient_checkpointing(self: Module):
|
||||
self.is_checkpointing = False
|
||||
|
||||
@torch.no_grad()
|
||||
def merge_out(self: Module, merge_out_weight=1.0):
|
||||
# make sure it is positive
|
||||
merge_out_weight = abs(merge_out_weight)
|
||||
# merging out is just merging in the negative of the weight
|
||||
self.merge_in(merge_weight=-merge_out_weight)
|
||||
|
||||
@torch.no_grad()
|
||||
def merge_in(self: Module, merge_weight=1.0):
|
||||
# get up/down weight
|
||||
up_weight = self.lora_up.weight.clone().float()
|
||||
down_weight = self.lora_down.weight.clone().float()
|
||||
|
||||
# extract weight from org_module
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
orig_dtype = org_sd["weight"].dtype
|
||||
weight = org_sd["weight"].float()
|
||||
|
||||
multiplier = merge_weight
|
||||
scale = self.scale
|
||||
# handle trainable scaler method locon does
|
||||
if hasattr(self, 'scalar'):
|
||||
scale = scale * self.scalar
|
||||
|
||||
# merge weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + multiplier * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
weight
|
||||
+ multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + multiplier * conved * scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight.to(orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
|
||||
# LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and
|
||||
# outputs the same. It is basically a LoRA but with the original module removed
|
||||
|
||||
# if a state dict is passed, use those weights instead of extracting
|
||||
# todo load from state dict
|
||||
network: Network = self.network_ref()
|
||||
lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)
|
||||
|
||||
extract_mode = lorm_config.extract_mode
|
||||
extract_mode_param = lorm_config.extract_mode_param
|
||||
parameter_threshold = lorm_config.parameter_threshold
|
||||
self.extract_weight(
|
||||
extract_mode=extract_mode,
|
||||
extract_mode_param=extract_mode_param
|
||||
)
|
||||
|
||||
|
||||
class ToolkitNetworkMixin:
|
||||
def __init__(
|
||||
self: Network,
|
||||
*args,
|
||||
train_text_encoder: Optional[bool] = True,
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
network_config: Optional[NetworkConfig] = None,
|
||||
is_lorm=False,
|
||||
**kwargs
|
||||
):
|
||||
self.train_text_encoder = train_text_encoder
|
||||
self.train_unet = train_unet
|
||||
self.is_checkpointing = False
|
||||
self._multiplier: float = 1.0
|
||||
self.is_active: bool = False
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
self.is_merged_in = False
|
||||
self.is_lorm = is_lorm
|
||||
self.network_config: NetworkConfig = network_config
|
||||
self.module_losses: List[torch.Tensor] = []
|
||||
self.lorm_train_mode: Literal['local', None] = None
|
||||
self.can_merge_in = not is_lorm
|
||||
|
||||
def get_keymap(self: Network):
|
||||
if self.is_sdxl:
|
||||
keymap_tail = 'sdxl'
|
||||
elif self.is_v2:
|
||||
keymap_tail = 'sd2'
|
||||
else:
|
||||
keymap_tail = 'sd1'
|
||||
# load keymap
|
||||
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
|
||||
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
|
||||
|
||||
keymap = None
|
||||
# check if file exists
|
||||
if os.path.exists(keymap_path):
|
||||
with open(keymap_path, 'r') as f:
|
||||
keymap = json.load(f)['ldm_diffusers_keymap']
|
||||
|
||||
return keymap
|
||||
|
||||
def save_weights(
|
||||
self: Network,
|
||||
file, dtype=torch.float16,
|
||||
metadata=None,
|
||||
extra_state_dict: Optional[OrderedDict] = None
|
||||
):
|
||||
keymap = self.get_keymap()
|
||||
|
||||
save_keymap = {}
|
||||
if keymap is not None:
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
# invert them
|
||||
save_keymap[diffusers_key] = ldm_key
|
||||
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
save_dict = OrderedDict()
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
save_key = save_keymap[key] if key in save_keymap else key
|
||||
save_dict[save_key] = v
|
||||
|
||||
if extra_state_dict is not None:
|
||||
# add extra items to state dict
|
||||
for key in list(extra_state_dict.keys()):
|
||||
v = extra_state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
save_dict[key] = v
|
||||
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
metadata = add_model_hash_to_meta(state_dict, metadata)
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
save_file(save_dict, file, metadata)
|
||||
else:
|
||||
torch.save(save_dict, file)
|
||||
|
||||
def load_weights(self: Network, file):
|
||||
# allows us to save and load to and from ldm weights
|
||||
keymap = self.get_keymap()
|
||||
keymap = {} if keymap is None else keymap
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
load_sd = OrderedDict()
|
||||
for key, value in weights_sd.items():
|
||||
load_key = keymap[key] if key in keymap else key
|
||||
load_sd[load_key] = value
|
||||
|
||||
# extract extra items from state dict
|
||||
current_state_dict = self.state_dict()
|
||||
extra_dict = OrderedDict()
|
||||
to_delete = []
|
||||
for key in list(load_sd.keys()):
|
||||
if key not in current_state_dict:
|
||||
extra_dict[key] = load_sd[key]
|
||||
to_delete.append(key)
|
||||
for key in to_delete:
|
||||
del load_sd[key]
|
||||
|
||||
info = self.load_state_dict(load_sd, False)
|
||||
if len(extra_dict.keys()) == 0:
|
||||
extra_dict = None
|
||||
return extra_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_torch_multiplier(self: Network):
|
||||
# builds a tensor for fast usage in the forward pass of the network modules
|
||||
# without having to set it in every single module every time it changes
|
||||
multiplier = self._multiplier
|
||||
# get first module
|
||||
first_module = self.get_all_modules()[0]
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
with torch.no_grad():
|
||||
tensor_multiplier = None
|
||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
|
||||
elif isinstance(multiplier, list):
|
||||
tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
|
||||
elif isinstance(multiplier, torch.Tensor):
|
||||
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
|
||||
|
||||
self.torch_multiplier = tensor_multiplier.clone().detach()
|
||||
|
||||
@property
|
||||
def multiplier(self) -> Union[float, List[float], List[List[float]]]:
|
||||
return self._multiplier
|
||||
|
||||
@multiplier.setter
|
||||
def multiplier(self, value: Union[float, List[float], List[List[float]]]):
|
||||
# it takes time to update all the multipliers, so we only do it if the value has changed
|
||||
if self._multiplier == value:
|
||||
return
|
||||
# if we are setting a single value but have a list, keep the list if every item is the same as value
|
||||
self._multiplier = value
|
||||
self._update_torch_multiplier()
|
||||
|
||||
# called when the context manager is entered
|
||||
# ie: with network:
|
||||
def __enter__(self: Network):
|
||||
self.is_active = True
|
||||
|
||||
def __exit__(self: Network, exc_type, exc_value, tb):
|
||||
self.is_active = False
|
||||
|
||||
def force_to(self: Network, device, dtype):
|
||||
self.to(device, dtype)
|
||||
loras = []
|
||||
if hasattr(self, 'unet_loras'):
|
||||
loras += self.unet_loras
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
loras += self.text_encoder_loras
|
||||
for lora in loras:
|
||||
lora.to(device, dtype)
|
||||
|
||||
def get_all_modules(self: Network) -> List[Module]:
|
||||
loras = []
|
||||
if hasattr(self, 'unet_loras'):
|
||||
loras += self.unet_loras
|
||||
if hasattr(self, 'text_encoder_loras'):
|
||||
loras += self.text_encoder_loras
|
||||
return loras
|
||||
|
||||
def _update_checkpointing(self: Network):
|
||||
for module in self.get_all_modules():
|
||||
if self.is_checkpointing:
|
||||
module.enable_gradient_checkpointing()
|
||||
else:
|
||||
module.disable_gradient_checkpointing()
|
||||
|
||||
def enable_gradient_checkpointing(self: Network):
|
||||
# not supported
|
||||
self.is_checkpointing = True
|
||||
self._update_checkpointing()
|
||||
|
||||
def disable_gradient_checkpointing(self: Network):
|
||||
# not supported
|
||||
self.is_checkpointing = False
|
||||
self._update_checkpointing()
|
||||
|
||||
def merge_in(self, merge_weight=1.0):
|
||||
self.is_merged_in = True
|
||||
for module in self.get_all_modules():
|
||||
module.merge_in(merge_weight)
|
||||
|
||||
def merge_out(self: Network, merge_weight=1.0):
|
||||
if not self.is_merged_in:
|
||||
return
|
||||
self.is_merged_in = False
|
||||
for module in self.get_all_modules():
|
||||
module.merge_out(merge_weight)
|
||||
|
||||
def extract_weight(
|
||||
self: Network,
|
||||
extract_mode: ExtractMode = "existing",
|
||||
extract_mode_param: Union[int, float] = None,
|
||||
):
|
||||
if extract_mode_param is None:
|
||||
raise ValueError("extract_mode_param must be set")
|
||||
for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
|
||||
module.extract_weight(
|
||||
extract_mode=extract_mode,
|
||||
extract_mode_param=extract_mode_param
|
||||
)
|
||||
|
||||
def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
|
||||
for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
|
||||
module.setup_lorm(state_dict=state_dict)
|
||||
|
||||
def calculate_lorem_parameter_reduction(self):
|
||||
params_reduced = 0
|
||||
for module in self.get_all_modules():
|
||||
num_orig_module_params = count_parameters(module.org_module[0])
|
||||
num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
|
||||
params_reduced += (num_orig_module_params - num_lorem_params)
|
||||
|
||||
return params_reduced
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from transformers import Adafactor
|
||||
|
||||
|
||||
def get_optimizer(
|
||||
@@ -35,6 +36,8 @@ def get_optimizer(
|
||||
if use_lr < 0.1:
|
||||
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
|
||||
use_lr = 1.0
|
||||
|
||||
print(f"Using lr {use_lr}")
|
||||
# let net be the neural network you want to train
|
||||
# you can choose weight decay value based on your problem, 0 by default
|
||||
optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
|
||||
@@ -43,6 +46,8 @@ def get_optimizer(
|
||||
|
||||
if lower_type == "adam8bit":
|
||||
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
|
||||
elif lower_type == "adamw8bit":
|
||||
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
|
||||
elif lower_type == "lion8bit":
|
||||
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
|
||||
else:
|
||||
@@ -52,10 +57,15 @@ def get_optimizer(
|
||||
elif lower_type == 'adamw':
|
||||
optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
|
||||
elif lower_type == 'lion':
|
||||
from lion_pytorch import Lion
|
||||
return Lion(params, lr=learning_rate, **optimizer_params)
|
||||
try:
|
||||
from lion_pytorch import Lion
|
||||
return Lion(params, lr=learning_rate, **optimizer_params)
|
||||
except ImportError:
|
||||
raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
|
||||
elif lower_type == 'adagrad':
|
||||
optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
|
||||
elif lower_type == 'adafactor':
|
||||
optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
|
||||
else:
|
||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||
return optimizer
|
||||
|
||||
91
toolkit/orig_configs/sd_xl_refiner.yaml
Normal file
91
toolkit/orig_configs/sd_xl_refiner.yaml
Normal file
@@ -0,0 +1,91 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2560
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 384
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 4
|
||||
context_dim: [1280, 1280, 1280, 1280] # 1280
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn and vector cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
legacy: False
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: aesthetic_score
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by one
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
@@ -5,6 +5,8 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
||||
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
||||
|
||||
# check if ENV variable is set
|
||||
if 'MODELS_PATH' in os.environ:
|
||||
|
||||
@@ -1,14 +1,297 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from diffusers.utils import is_torch_xla_available
|
||||
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: 'AutoencoderKL',
|
||||
text_encoder: 'CLIPTextModel',
|
||||
text_encoder_2: 'CLIPTextModelWithProjection',
|
||||
tokenizer: 'CLIPTokenizer',
|
||||
tokenizer_2: 'CLIPTokenizer',
|
||||
unet: 'UNet2DConditionModel',
|
||||
scheduler: 'KarrasDiffusionSchedulers',
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.sampler = None
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
model = ModelWrapper(unet, scheduler.alphas_cumprod)
|
||||
if scheduler.config.prediction_type == "v_prediction":
|
||||
self.k_diffusion_model = CompVisVDenoiser(model)
|
||||
else:
|
||||
self.k_diffusion_model = CompVisDenoiser(model)
|
||||
|
||||
def set_scheduler(self, scheduler_type: str):
|
||||
library = importlib.import_module("k_diffusion")
|
||||
sampling = getattr(library, "sampling")
|
||||
self.sampler = getattr(sampling, scheduler_type)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
use_karras_sigmas: bool = False,
|
||||
):
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 5. Prepare sigmas
|
||||
if use_karras_sigmas:
|
||||
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
|
||||
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
|
||||
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sigmas = sigmas.to(device)
|
||||
else:
|
||||
sigmas = self.scheduler.sigmas
|
||||
sigmas = sigmas.to(prompt_embeds.dtype)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
latents = latents * sigmas[0]
|
||||
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
|
||||
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
|
||||
|
||||
# 7. Define model function
|
||||
def model_fn(x, t):
|
||||
latent_model_input = torch.cat([x] * 2)
|
||||
t = torch.cat([t] * 2)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
# noise_pred = self.unet(
|
||||
# latent_model_input,
|
||||
# t,
|
||||
# encoder_hidden_states=prompt_embeds,
|
||||
# cross_attention_kwargs=cross_attention_kwargs,
|
||||
# added_cond_kwargs=added_cond_kwargs,
|
||||
# return_dict=False,
|
||||
# )[0]
|
||||
|
||||
noise_pred = self.k_diffusion_model(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,)[0]
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
return noise_pred
|
||||
|
||||
|
||||
# 8. Run k-diffusion solver
|
||||
sampler_kwargs = {}
|
||||
# should work without it
|
||||
noise_sampler_seed = None
|
||||
|
||||
|
||||
if "noise_sampler" in inspect.signature(self.sampler).parameters:
|
||||
min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
|
||||
sampler_kwargs["noise_sampler"] = noise_sampler
|
||||
|
||||
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
|
||||
class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, **kwargs)
|
||||
|
||||
def predict_noise(
|
||||
self,
|
||||
@@ -532,3 +815,389 @@ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
return noise_pred
|
||||
|
||||
|
||||
class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
denoising_start (`float`, *optional*):
|
||||
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
|
||||
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
|
||||
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
|
||||
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 8.1 Apply denoising_end
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 8.2 Determine denoising_start
|
||||
denoising_start_index = 0
|
||||
if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
|
||||
discrete_timestep_start = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_start * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
|
||||
|
||||
|
||||
with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
|
||||
for i, t in enumerate(timesteps, start=denoising_start_index):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
if not output_type == "latent":
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
|
||||
25
toolkit/progress_bar.py
Normal file
25
toolkit/progress_bar.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
|
||||
class ToolkitProgressBar(tqdm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.paused = False
|
||||
self.last_time = self._time()
|
||||
|
||||
def pause(self):
|
||||
if not self.paused:
|
||||
self.paused = True
|
||||
self.last_time = self._time()
|
||||
|
||||
def unpause(self):
|
||||
if self.paused:
|
||||
self.paused = False
|
||||
cur_t = self._time()
|
||||
self.start_t += cur_t - self.last_time
|
||||
self.last_print_t = cur_t
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
if not self.paused:
|
||||
super().update(*args, **kwargs)
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
from typing import Optional, TYPE_CHECKING, List
|
||||
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import itertools
|
||||
|
||||
@@ -19,6 +18,39 @@ class ACTION_TYPES_SLIDER:
|
||||
ENHANCE_NEGATIVE = 1
|
||||
|
||||
|
||||
class PromptEmbeds:
|
||||
text_embeds: torch.Tensor
|
||||
pooled_embeds: Union[torch.Tensor, None]
|
||||
|
||||
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
|
||||
if isinstance(args, list) or isinstance(args, tuple):
|
||||
# xl
|
||||
self.text_embeds = args[0]
|
||||
self.pooled_embeds = args[1]
|
||||
else:
|
||||
# sdv1.x, sdv2.x
|
||||
self.text_embeds = args
|
||||
self.pooled_embeds = None
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.text_embeds = self.text_embeds.to(*args, **kwargs)
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
self.text_embeds = self.text_embeds.detach()
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.detach()
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
if self.pooled_embeds is not None:
|
||||
return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
|
||||
else:
|
||||
return PromptEmbeds(self.text_embeds.clone())
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -73,6 +105,18 @@ class EncodedPromptPair:
|
||||
self.both_targets = self.both_targets.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
self.target_class = self.target_class.detach()
|
||||
self.target_class_with_neutral = self.target_class_with_neutral.detach()
|
||||
self.positive_target = self.positive_target.detach()
|
||||
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
|
||||
self.negative_target = self.negative_target.detach()
|
||||
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
|
||||
self.neutral = self.neutral.detach()
|
||||
self.empty_prompt = self.empty_prompt.detach()
|
||||
self.both_targets = self.both_targets.detach()
|
||||
return self
|
||||
|
||||
|
||||
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
|
||||
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
|
||||
@@ -235,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
|
||||
return anchors
|
||||
|
||||
|
||||
def get_permutations(s):
|
||||
def get_permutations(s, max_permutations=8):
|
||||
# Split the string by comma
|
||||
phrases = [phrase.strip() for phrase in s.split(',')]
|
||||
|
||||
# remove empty strings
|
||||
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
||||
# shuffle the list
|
||||
random.shuffle(phrases)
|
||||
|
||||
# Get all permutations
|
||||
permutations = list(itertools.permutations(phrases))
|
||||
permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])
|
||||
|
||||
# Convert the tuples back to comma separated strings
|
||||
return [', '.join(permutation) for permutation in permutations]
|
||||
@@ -251,8 +297,8 @@ def get_permutations(s):
|
||||
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
pos_permutations = get_permutations(target.positive)
|
||||
neg_permutations = get_permutations(target.negative)
|
||||
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
|
||||
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
|
||||
|
||||
permutations = []
|
||||
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
||||
@@ -465,3 +511,39 @@ def build_latent_image_batch_for_prompt_pair(
|
||||
latent_list.append(neg_latent)
|
||||
|
||||
return torch.cat(latent_list, dim=0)
|
||||
|
||||
|
||||
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
||||
if trigger is None:
|
||||
# process as empty string to remove any [trigger] tokens
|
||||
trigger = ''
|
||||
output_prompt = prompt
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
to_replace_list += default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
if trigger.strip() != "":
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
output_prompt = replace_with + " " + output_prompt
|
||||
|
||||
# if num_instances > 1:
|
||||
# print(
|
||||
# f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
116
toolkit/sampler.py
Normal file
116
toolkit/sampler.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import copy
|
||||
|
||||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
LCMScheduler
|
||||
)
|
||||
|
||||
from k_diffusion.external import CompVisDenoiser
|
||||
|
||||
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
|
||||
|
||||
# scheduler:
|
||||
SCHEDULER_LINEAR_START = 0.00085
|
||||
SCHEDULER_LINEAR_END = 0.0120
|
||||
SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
sdxl_sampler_config = {
|
||||
"_class_name": "EulerDiscreteScheduler",
|
||||
"_diffusers_version": "0.19.0.dev0",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": False,
|
||||
"interpolation_type": "linear",
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": False,
|
||||
"skip_prk_steps": True,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": None,
|
||||
"use_karras_sigmas": False
|
||||
}
|
||||
|
||||
|
||||
def get_sampler(
|
||||
sampler: str,
|
||||
):
|
||||
sched_init_args = {}
|
||||
|
||||
if sampler.startswith("k_"):
|
||||
sched_init_args["use_karras_sigmas"] = True
|
||||
|
||||
if sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
elif sampler == "ddpm": # ddpm is not supported ?
|
||||
scheduler_cls = DDPMScheduler
|
||||
elif sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
elif sampler == "lms" or sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
elif sampler == "euler" or sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
elif sampler == "euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = sampler.replace("k_", "")
|
||||
elif sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
elif sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
elif sampler == "dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif sampler == "dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
elif sampler == "lcm":
|
||||
scheduler_cls = LCMScheduler
|
||||
elif sampler == "custom_lcm":
|
||||
scheduler_cls = CustomLCMScheduler
|
||||
|
||||
config = copy.deepcopy(sdxl_sampler_config)
|
||||
config.update(sched_init_args)
|
||||
|
||||
scheduler = scheduler_cls.from_config(config)
|
||||
|
||||
|
||||
return scheduler
|
||||
|
||||
|
||||
# testing
|
||||
if __name__ == "__main__":
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from diffusers import StableDiffusionKDiffusionPipeline
|
||||
import torch
|
||||
import os
|
||||
|
||||
inference_steps = 25
|
||||
|
||||
pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
k_diffusion_model = CompVisDenoiser(model)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "an astronaut riding a horse on mars"
|
||||
pipe.set_scheduler("sample_heun")
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
|
||||
image.save("./astronaut_heun_k_diffusion.png")
|
||||
553
toolkit/samplers/custom_lcm_scheduler.py
Normal file
553
toolkit/samplers/custom_lcm_scheduler.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class LCMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
denoised: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class CustomLCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
||||
non-Markovian guidance.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
|
||||
attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
|
||||
accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
|
||||
functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
original_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
|
||||
will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Clip the predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, defaults to `True`):
|
||||
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
timestep_scaling (`float`, defaults to 10.0):
|
||||
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
|
||||
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
|
||||
error at the default of `10.0` is already pretty small).
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.00085,
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "scaled_linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
original_inference_steps: int = 50,
|
||||
clip_sample: bool = False,
|
||||
clip_sample_range: float = 1.0,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_scaling: float = 10.0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
# At every step in ddim, we are looking into the previous alphas_cumprod
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
self.original_inference_steps = 50
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||
|
||||
self.train_timesteps = 1000
|
||||
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The current timestep in the diffusion chain.
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, *remaining_dims = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Union[str, torch.device] = None,
|
||||
strength: int = 1.0,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
original_inference_steps (`int`, *optional*):
|
||||
The original number of inference steps, which will be used to generate a linearly-spaced timestep
|
||||
schedule (which is different from the standard `diffusers` implementation). We will then take
|
||||
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
|
||||
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
|
||||
"""
|
||||
|
||||
original_inference_steps = self.original_inference_steps
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
original_steps = (
|
||||
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
|
||||
)
|
||||
|
||||
if original_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
if num_inference_steps > original_steps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
|
||||
f" {original_steps} because the final timestep schedule will be a subset of the"
|
||||
f" `original_inference_steps`-sized initial timestep schedule."
|
||||
)
|
||||
|
||||
# LCM Timesteps Setting
|
||||
# The skipping step parameter k from the paper.
|
||||
k = self.config.num_train_timesteps // original_steps
|
||||
# LCM Training/Distillation Steps Schedule
|
||||
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
|
||||
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
|
||||
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
||||
|
||||
if skipping_step < 1:
|
||||
raise ValueError(
|
||||
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
|
||||
)
|
||||
|
||||
# LCM Inference Steps Schedule
|
||||
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
|
||||
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
|
||||
inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
|
||||
inference_indices = np.floor(inference_indices).astype(np.int64)
|
||||
timesteps = lcm_origin_timesteps[inference_indices]
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
||||
self.sigma_data = 0.5 # Default: 0.5
|
||||
scaled_timestep = timestep * self.config.timestep_scaling
|
||||
|
||||
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[LCMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# 1. get previous step value
|
||||
prev_step_index = self.step_index + 1
|
||||
if prev_step_index < len(self.timesteps):
|
||||
prev_timestep = self.timesteps[prev_step_index]
|
||||
else:
|
||||
prev_timestep = timestep
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
# 3. Get scalings for boundary conditions
|
||||
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
||||
|
||||
# 4. Compute the predicted original sample x_0 based on the model parameterization
|
||||
if self.config.prediction_type == "epsilon": # noise-prediction
|
||||
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
||||
elif self.config.prediction_type == "sample": # x-prediction
|
||||
predicted_original_sample = model_output
|
||||
elif self.config.prediction_type == "v_prediction": # v-prediction
|
||||
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
||||
" `v_prediction` for `LCMScheduler`."
|
||||
)
|
||||
|
||||
# 5. Clip or threshold "predicted x_0"
|
||||
if self.config.thresholding:
|
||||
predicted_original_sample = self._threshold_sample(predicted_original_sample)
|
||||
elif self.config.clip_sample:
|
||||
predicted_original_sample = predicted_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
# 6. Denoise model output using boundary conditions
|
||||
denoised = c_out * predicted_original_sample + c_skip * sample
|
||||
|
||||
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
|
||||
# Noise is not used on the final timestep of the timestep schedule.
|
||||
# This also means that noise is not used for one-step sampling.
|
||||
if self.step_index != self.num_inference_steps - 1:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
|
||||
)
|
||||
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
||||
else:
|
||||
prev_sample = denoised
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, denoised)
|
||||
|
||||
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
||||
def get_velocity(
|
||||
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
||||
timesteps = timesteps.to(sample.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -32,6 +32,10 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
with open(mapping_path, 'r') as f:
|
||||
mapping = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
# keep track of keys not matched
|
||||
ldm_matched_keys = []
|
||||
diffusers_matched_keys = []
|
||||
|
||||
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
|
||||
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
|
||||
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
|
||||
@@ -52,11 +56,15 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
|
||||
cat_list.append(diffusers_state_dict[diffusers_key].detach())
|
||||
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
|
||||
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat'])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
if 'slice' in ldm_diffusers_operator_map[ldm_key]:
|
||||
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
|
||||
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
|
||||
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
|
||||
dtype=dtype)
|
||||
diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice'])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
|
||||
# process the rest of the keys
|
||||
for ldm_key in ldm_diffusers_keymap:
|
||||
@@ -67,13 +75,29 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
if ldm_key in ldm_diffusers_shape_map:
|
||||
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
|
||||
converted_state_dict[ldm_key] = tensor
|
||||
diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key])
|
||||
ldm_matched_keys.append(ldm_key)
|
||||
|
||||
# see if any are missing from know mapping
|
||||
mapped_diffusers_keys = list(ldm_diffusers_keymap.values())
|
||||
mapped_ldm_keys = list(ldm_diffusers_keymap.keys())
|
||||
|
||||
missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys]
|
||||
missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys]
|
||||
|
||||
if len(missing_diffusers_keys) > 0:
|
||||
print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys")
|
||||
print(missing_diffusers_keys)
|
||||
if len(missing_ldm_keys) > 0:
|
||||
print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys")
|
||||
print(missing_ldm_keys)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -87,6 +111,14 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
elif sd_version == 'ssd':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
elif sd_version == 'sdxl_refiner':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
@@ -105,7 +137,7 @@ def save_ldm_model_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
@@ -117,3 +149,95 @@ def save_ldm_model_from_diffusers(
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def save_lora_from_diffusers(
|
||||
lora_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = OrderedDict()
|
||||
# only handle sxdxl for now
|
||||
if sd_version != 'sdxl' and sd_version != 'ssd':
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
for key, value in lora_state_dict.items():
|
||||
# todo verify if this works with ssd
|
||||
# test encoders share keys for some reason
|
||||
if key.begins_with('lora_te'):
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
else:
|
||||
converted_key = key
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def save_t2i_from_diffusers(
|
||||
t2i_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
dtype=get_torch_dtype('fp16'),
|
||||
):
|
||||
# todo: test compatibility with non diffusers
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in t2i_state_dict.items():
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=dtype)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def load_t2i_model(
|
||||
path_to_file,
|
||||
device: Union[str] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
converted_state_dict = OrderedDict()
|
||||
for key, value in raw_state_dict.items():
|
||||
# todo see if we need to convert dict
|
||||
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
||||
|
||||
def save_ip_adapter_from_diffusers(
|
||||
combined_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
dtype=get_torch_dtype('fp16'),
|
||||
):
|
||||
# todo: test compatibility with non diffusers
|
||||
converted_state_dict = OrderedDict()
|
||||
for module_name, state_dict in combined_state_dict.items():
|
||||
for key, value in state_dict.items():
|
||||
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def load_ip_adapter_model(
|
||||
path_to_file,
|
||||
device: Union[str] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
# check if it is safetensors or checkpoint
|
||||
if path_to_file.endswith('.safetensors'):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
combined_state_dict = OrderedDict()
|
||||
for combo_key, value in raw_state_dict.items():
|
||||
key_split = combo_key.split('.')
|
||||
module_name = key_split.pop(0)
|
||||
if module_name not in combined_state_dict:
|
||||
combined_state_dict[module_name] = OrderedDict()
|
||||
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
@@ -1,33 +1,57 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup
|
||||
|
||||
|
||||
def get_lr_scheduler(
|
||||
name: Optional[str],
|
||||
optimizer: torch.optim.Optimizer,
|
||||
max_iterations: Optional[int],
|
||||
lr_min: Optional[float],
|
||||
**kwargs,
|
||||
):
|
||||
if name == "cosine":
|
||||
if 'total_iters' in kwargs:
|
||||
kwargs['T_max'] = kwargs.pop('total_iters')
|
||||
return torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == "cosine_with_restarts":
|
||||
if 'total_iters' in kwargs:
|
||||
kwargs['T_0'] = kwargs.pop('total_iters')
|
||||
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer, T_0=max_iterations, T_mult=2, eta_min=lr_min, **kwargs
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == "step":
|
||||
|
||||
return torch.optim.lr_scheduler.StepLR(
|
||||
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == "constant":
|
||||
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
|
||||
if 'facor' not in kwargs:
|
||||
kwargs['factor'] = 1.0
|
||||
|
||||
return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
|
||||
elif name == "linear":
|
||||
|
||||
return torch.optim.lr_scheduler.LinearLR(
|
||||
optimizer, start_factor=0.5, end_factor=0.5, total_iters=max_iterations, **kwargs
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == 'constant_with_warmup':
|
||||
# see if num_warmup_steps is in kwargs
|
||||
if 'num_warmup_steps' not in kwargs:
|
||||
print(f"WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
|
||||
kwargs['num_warmup_steps'] = 1000
|
||||
del kwargs['total_iters']
|
||||
return get_constant_schedule_with_warmup(optimizer, **kwargs)
|
||||
else:
|
||||
# try to use a diffusers scheduler
|
||||
print(f"Trying to use diffusers scheduler {name}")
|
||||
try:
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
return schedule_func(optimizer, **kwargs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
raise ValueError(
|
||||
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
|
||||
)
|
||||
|
||||
88
toolkit/sd_device_states_presets.py
Normal file
88
toolkit/sd_device_states_presets.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import copy
|
||||
|
||||
empty_preset = {
|
||||
'vae': {
|
||||
'training': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
'unet': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
'text_encoder': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
'adapter': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
'refiner_unet': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_train_sd_device_state_preset(
|
||||
device: Union[str, torch.device],
|
||||
train_unet: bool = False,
|
||||
train_text_encoder: bool = False,
|
||||
cached_latents: bool = False,
|
||||
train_lora: bool = False,
|
||||
train_adapter: bool = False,
|
||||
train_embedding: bool = False,
|
||||
train_refiner: bool = False,
|
||||
):
|
||||
preset = copy.deepcopy(empty_preset)
|
||||
if not cached_latents:
|
||||
preset['vae']['device'] = device
|
||||
|
||||
if train_unet:
|
||||
preset['unet']['training'] = True
|
||||
preset['unet']['requires_grad'] = True
|
||||
preset['unet']['device'] = device
|
||||
else:
|
||||
preset['unet']['device'] = device
|
||||
|
||||
if train_text_encoder:
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['text_encoder']['requires_grad'] = True
|
||||
preset['text_encoder']['device'] = device
|
||||
else:
|
||||
preset['text_encoder']['device'] = device
|
||||
|
||||
if train_embedding:
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['text_encoder']['requires_grad'] = True
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['unet']['training'] = True
|
||||
|
||||
if train_refiner:
|
||||
preset['refiner_unet']['training'] = True
|
||||
preset['refiner_unet']['requires_grad'] = True
|
||||
preset['refiner_unet']['device'] = device
|
||||
# if not training unet, move that to cpu
|
||||
if not train_unet:
|
||||
preset['unet']['device'] = 'cpu'
|
||||
|
||||
if train_lora:
|
||||
# preset['text_encoder']['requires_grad'] = False
|
||||
preset['unet']['requires_grad'] = False
|
||||
if train_refiner:
|
||||
preset['refiner_unet']['requires_grad'] = False
|
||||
|
||||
if train_adapter:
|
||||
preset['adapter']['requires_grad'] = True
|
||||
preset['adapter']['training'] = True
|
||||
preset['adapter']['device'] = device
|
||||
preset['unet']['training'] = True
|
||||
|
||||
return preset
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
|
||||
|
||||
# Define the separate loss function
|
||||
def separated_loss(y_pred, y_true):
|
||||
y_pred = y_pred.float()
|
||||
y_true = y_true.float()
|
||||
diff = torch.abs(y_pred - y_true)
|
||||
l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
|
||||
return 2. * l2 / content_size
|
||||
|
||||
# Calculate itemized loss
|
||||
pred_itemized_loss = separated_loss(pred_layer, target_layer)
|
||||
# check if is nan
|
||||
if torch.isnan(pred_itemized_loss).any():
|
||||
print('pred_itemized_loss is nan')
|
||||
|
||||
# Calculate the mean of itemized loss
|
||||
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
|
||||
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
|
||||
|
||||
|
||||
def convert_to_gram_matrix(inputs):
|
||||
inputs = inputs.float()
|
||||
shape = inputs.size()
|
||||
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
|
||||
size = height * width * filters
|
||||
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
|
||||
target_grams = convert_to_gram_matrix(style_target)
|
||||
pred_grams = convert_to_gram_matrix(preds)
|
||||
itemized_loss = separated_loss(pred_grams, target_grams)
|
||||
# check if is nan
|
||||
if torch.isnan(itemized_loss).any():
|
||||
print('itemized_loss is nan')
|
||||
# reshape itemized loss to be (batch, 1, 1, 1)
|
||||
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
|
||||
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
|
||||
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
|
||||
self.loss = loss.to(input_dtype)
|
||||
self.loss = loss.to(input_dtype).float()
|
||||
return stacked_input.to(input_dtype)
|
||||
|
||||
|
||||
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv4_2']
|
||||
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
|
||||
# set all weights in the model to our dtype
|
||||
|
||||
65
toolkit/timer.py
Normal file
65
toolkit/timer.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import time
|
||||
from collections import OrderedDict, deque
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self, name='Timer', max_buffer=10):
|
||||
self.name = name
|
||||
self.max_buffer = max_buffer
|
||||
self.timers = OrderedDict()
|
||||
self.active_timers = {}
|
||||
self.current_timer = None # Used for the context manager functionality
|
||||
|
||||
def start(self, timer_name):
|
||||
if timer_name not in self.timers:
|
||||
self.timers[timer_name] = deque(maxlen=self.max_buffer)
|
||||
self.active_timers[timer_name] = time.time()
|
||||
|
||||
def cancel(self, timer_name):
|
||||
"""Cancel an active timer."""
|
||||
if timer_name in self.active_timers:
|
||||
del self.active_timers[timer_name]
|
||||
|
||||
def stop(self, timer_name):
|
||||
if timer_name not in self.active_timers:
|
||||
raise ValueError(f"Timer '{timer_name}' was not started!")
|
||||
|
||||
elapsed_time = time.time() - self.active_timers[timer_name]
|
||||
self.timers[timer_name].append(elapsed_time)
|
||||
|
||||
# Clean up active timers
|
||||
del self.active_timers[timer_name]
|
||||
|
||||
# Check if this timer's buffer exceeds max_buffer and remove the oldest if it does
|
||||
if len(self.timers[timer_name]) > self.max_buffer:
|
||||
self.timers[timer_name].popleft()
|
||||
|
||||
def print(self):
|
||||
print(f"\nTimer '{self.name}':")
|
||||
# sort by longest at top
|
||||
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
|
||||
avg_time = sum(timings) / len(timings)
|
||||
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
|
||||
|
||||
print('')
|
||||
|
||||
def reset(self):
|
||||
self.timers.clear()
|
||||
self.active_timers.clear()
|
||||
|
||||
def __call__(self, timer_name):
|
||||
"""Enable the use of the Timer class as a context manager."""
|
||||
self.current_timer = timer_name
|
||||
self.start(timer_name)
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type is None:
|
||||
# No exceptions, stop the timer normally
|
||||
self.stop(self.current_timer)
|
||||
else:
|
||||
# There was an exception, cancel the timer
|
||||
self.cancel(self.current_timer)
|
||||
@@ -5,6 +5,9 @@ import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Union
|
||||
import sys
|
||||
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
@@ -444,29 +447,78 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def text_tokenize(
|
||||
tokenizer: 'CLIPTokenizer', # 普通ならひとつ、XLならふたつ!
|
||||
tokenizer: 'CLIPTokenizer',
|
||||
prompts: list[str],
|
||||
truncate: bool = True,
|
||||
max_length: int = None,
|
||||
max_length_multiplier: int = 4,
|
||||
):
|
||||
return tokenizer(
|
||||
# allow fo up to 4x the max length for long prompts
|
||||
if max_length is None:
|
||||
if truncate:
|
||||
max_length = tokenizer.model_max_length
|
||||
else:
|
||||
# allow up to 4x the max length for long prompts
|
||||
max_length = tokenizer.model_max_length * max_length_multiplier
|
||||
|
||||
input_ids = tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
padding='max_length',
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
if truncate or max_length == tokenizer.model_max_length:
|
||||
return input_ids
|
||||
else:
|
||||
# remove additional padding
|
||||
num_chunks = input_ids.shape[1] // tokenizer.model_max_length
|
||||
chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
|
||||
|
||||
# New list to store non-redundant chunks
|
||||
non_redundant_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element
|
||||
non_redundant_chunks.append(chunk)
|
||||
|
||||
input_ids = torch.cat(non_redundant_chunks, dim=1)
|
||||
return input_ids
|
||||
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
|
||||
def text_encode_xl(
|
||||
text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
|
||||
tokens: torch.FloatTensor,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_length: int = 77, # not sure what default to put here, always pass one?
|
||||
truncate: bool = True,
|
||||
):
|
||||
prompt_embeds = text_encoder(
|
||||
tokens.to(text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
|
||||
if truncate:
|
||||
# normal short prompt 77 tokens max
|
||||
prompt_embeds = text_encoder(
|
||||
tokens.to(text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
|
||||
else:
|
||||
# handle long prompts
|
||||
prompt_embeds_list = []
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
pooled_prompt_embeds = None
|
||||
for i in range(0, tokens.shape[-1], max_length):
|
||||
# todo run it through the in a single batch
|
||||
section_tokens = tokens[:, i: i + max_length]
|
||||
embeds = text_encoder(section_tokens, output_hidden_states=True)
|
||||
pooled_prompt_embed = embeds[0]
|
||||
if pooled_prompt_embeds is None:
|
||||
# we only want the first ( I think??)
|
||||
pooled_prompt_embeds = pooled_prompt_embed
|
||||
prompt_embed = embeds.hidden_states[-2] # always penultimate layer
|
||||
prompt_embeds_list.append(prompt_embed)
|
||||
|
||||
prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -479,26 +531,43 @@ def encode_prompts_xl(
|
||||
tokenizers: list['CLIPTokenizer'],
|
||||
text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
|
||||
prompts: list[str],
|
||||
prompts2: Union[list[str], None],
|
||||
num_images_per_prompt: int = 1,
|
||||
use_text_encoder_1: bool = True, # sdxl
|
||||
use_text_encoder_2: bool = True # sdxl
|
||||
use_text_encoder_2: bool = True, # sdxl
|
||||
truncate: bool = True,
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# text_encoder and text_encoder_2's penuultimate layer's output
|
||||
text_embeds_list = []
|
||||
pooled_text_embeds = None # always text_encoder_2's pool
|
||||
if prompts2 is None:
|
||||
prompts2 = prompts
|
||||
|
||||
for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
|
||||
# todo, we are using a blank string to ignore that encoder for now.
|
||||
# find a better way to do this (zeroing?, removing it from the unet?)
|
||||
prompt_list_to_use = prompts
|
||||
prompt_list_to_use = prompts if idx == 0 else prompts2
|
||||
if idx == 0 and not use_text_encoder_1:
|
||||
prompt_list_to_use = ["" for _ in prompts]
|
||||
if idx == 1 and not use_text_encoder_2:
|
||||
prompt_list_to_use = ["" for _ in prompts]
|
||||
|
||||
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
|
||||
if dropout_prob > 0.0:
|
||||
# randomly drop out prompts
|
||||
prompt_list_to_use = [
|
||||
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
|
||||
]
|
||||
|
||||
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
|
||||
# set the max length for the next one
|
||||
if idx == 0:
|
||||
max_length = text_tokens_input_ids.shape[-1]
|
||||
|
||||
text_embeds, pooled_text_embeds = text_encode_xl(
|
||||
text_encoder, text_tokens_input_ids, num_images_per_prompt
|
||||
text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
|
||||
truncate=truncate
|
||||
)
|
||||
|
||||
text_embeds_list.append(text_embeds)
|
||||
@@ -511,17 +580,49 @@ def encode_prompts_xl(
|
||||
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
|
||||
|
||||
|
||||
def text_encode(text_encoder: 'CLIPTextModel', tokens):
|
||||
return text_encoder(tokens.to(text_encoder.device))[0]
|
||||
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
|
||||
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
|
||||
if max_length is None and not truncate:
|
||||
raise ValueError("max_length must be set if truncate is True")
|
||||
try:
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("tokens.device", tokens.device)
|
||||
print("text_encoder.device", text_encoder.device)
|
||||
raise e
|
||||
|
||||
if truncate:
|
||||
return text_encoder(tokens)[0]
|
||||
else:
|
||||
# handle long prompts
|
||||
prompt_embeds_list = []
|
||||
for i in range(0, tokens.shape[-1], max_length):
|
||||
prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
return torch.cat(prompt_embeds_list, dim=1)
|
||||
|
||||
|
||||
def encode_prompts(
|
||||
tokenizer: 'CLIPTokenizer',
|
||||
text_encoder: 'CLIPTokenizer',
|
||||
text_encoder: 'CLIPTextModel',
|
||||
prompts: list[str],
|
||||
truncate: bool = True,
|
||||
max_length=None,
|
||||
dropout_prob=0.0,
|
||||
):
|
||||
text_tokens = text_tokenize(tokenizer, prompts)
|
||||
text_embeddings = text_encode(text_encoder, text_tokens)
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
if dropout_prob > 0.0:
|
||||
# randomly drop out prompts
|
||||
prompts = [
|
||||
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
|
||||
]
|
||||
|
||||
text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
|
||||
text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
@@ -602,18 +703,86 @@ def get_all_snr(noise_scheduler, device):
|
||||
all_snr.requires_grad = False
|
||||
return all_snr.to(device)
|
||||
|
||||
class LearnableSNRGamma:
|
||||
"""
|
||||
This is a trainer for learnable snr gamma
|
||||
It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
|
||||
"""
|
||||
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
|
||||
self.device = device
|
||||
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
|
||||
self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
|
||||
self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
|
||||
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
|
||||
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
|
||||
self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
|
||||
self.buffer = []
|
||||
self.max_buffer_size = 20
|
||||
|
||||
def forward(self, loss, timesteps):
|
||||
# do a our train loop for lsnr here and return our values detached
|
||||
loss = loss.detach()
|
||||
with torch.no_grad():
|
||||
loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
|
||||
for loss_chunk in loss_chunks:
|
||||
self.buffer.append(loss_chunk.mean().detach())
|
||||
if len(self.buffer) > self.max_buffer_size:
|
||||
self.buffer.pop(0)
|
||||
all_snr = get_all_snr(self.noise_scheduler, loss.device)
|
||||
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
|
||||
base_snrs = snr.clone().detach()
|
||||
snr.requires_grad = True
|
||||
snr = (snr + self.offset_1) * self.scale + self.offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
snr_adjusted_loss = loss * snr_weight
|
||||
with torch.no_grad():
|
||||
target = torch.mean(torch.stack(self.buffer)).detach()
|
||||
|
||||
# local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
|
||||
squared_differences = (snr_adjusted_loss - target) ** 2
|
||||
local_loss = torch.mean(squared_differences)
|
||||
local_loss.backward()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()
|
||||
|
||||
|
||||
def apply_learnable_snr_gos(
|
||||
loss,
|
||||
timesteps,
|
||||
learnable_snr_trainer: LearnableSNRGamma
|
||||
):
|
||||
|
||||
snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||
|
||||
snr = (snr + offset_1) * scale + offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
snr_adjusted_loss = loss * snr_weight
|
||||
|
||||
return snr_adjusted_loss
|
||||
|
||||
|
||||
def apply_snr_weight(
|
||||
loss,
|
||||
timesteps,
|
||||
noise_scheduler: Union['DDPMScheduler'],
|
||||
gamma
|
||||
gamma,
|
||||
fixed=False,
|
||||
):
|
||||
# will get it form noise scheduler if exist or will calculate it if not
|
||||
# will get it from noise scheduler if exist or will calculate it if not
|
||||
all_snr = get_all_snr(noise_scheduler, loss.device)
|
||||
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
|
||||
snr = torch.stack([all_snr[t] for t in step_indices])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
if fixed:
|
||||
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
|
||||
else:
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
|
||||
snr_adjusted_loss = loss * snr_weight
|
||||
|
||||
return snr_adjusted_loss
|
||||
|
||||
Reference in New Issue
Block a user