Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Merge pull request #7 from ostris/textual_inversion
Textual inversion training
152
extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py
Normal file
@@ -0,0 +1,152 @@
import copy
import random
from collections import OrderedDict
import os
from contextlib import nullcontext
from typing import Optional, Union, List
from torch.utils.data import ConcatDataset, DataLoader

from toolkit.config_modules import ReferenceDatasetConfig
from toolkit.data_loader import PairedImageDataset, ImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
from toolkit.basic import value_map


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class TextualInversionTrainer(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        pass

    def before_model_load(self):
        pass

    def hook_before_train_loop(self):
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)

        # keep original embeddings as reference
        self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
        # set text encoder to train. Not sure if this is necessary but diffusers example did it
        self.sd.text_encoder.train()
        pass

    def hook_train_loop(self, batch):
        with torch.no_grad():
            imgs, prompts = batch

            # very loosely based on this. very loosely
            # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py

            # make sure the embedding is in the prompts
            conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
                x,
                expand_token=True,
                add_if_not_present=True,
            ) for x in prompts]

            batch_size = imgs.shape[0]

            dtype = get_torch_dtype(self.train_config.dtype)
            imgs = imgs.to(self.device_torch, dtype=dtype)
            latents = self.sd.encode_images(imgs)

            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=imgs.shape[2],
                pixel_width=imgs.shape[3],
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset
            ).to(self.device_torch, dtype=dtype)

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # remove grads for these
            noisy_latents.requires_grad = False
            noise.requires_grad = False

        flush()

        self.optimizer.zero_grad()
        noisy_latents.requires_grad = False

        # text encoding
        embedding_list = []
        # embed the prompts
        for prompt in conditioned_prompts:
            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
            embedding_list.append(embedding)
        conditional_embeds = concat_prompt_embeds(embedding_list)

        noise_pred = self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            timestep=timesteps,
            guidance_scale=1.0,
        )
        noise = noise.to(self.device_torch, dtype=dtype)

        if self.sd.prediction_type == 'v_prediction':
            # v-parameterization training
            target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
        else:
            target = noise

        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1, 2, 3])

        if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
            # add min_snr_gamma
            loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()

        # back propagate loss to free ram
        loss.backward()
        flush()

        # apply gradients
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        # Let's make sure we don't update any embedding weights besides the newly added token
        index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
        index_no_updates[
            min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
        with torch.no_grad():
            self.sd.text_encoder.get_input_embeddings().weight[
                index_no_updates
            ] = self.orig_embeds_params[index_no_updates]

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )

        return loss_dict
        # end hook_train_loop
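apply_snr_weight is imported from toolkit.train_tools and its implementation is not part of this diff. For context, below is a minimal, hedged sketch of the standard min-SNR-gamma loss weighting (Hang et al., 2023) that such a helper typically applies for epsilon-prediction targets; the function name, signature, and the scheduler's alphas_cumprod attribute are assumptions here, not the toolkit's actual code.

import torch

def min_snr_weight_sketch(loss, timesteps, noise_scheduler, gamma):
    # Hypothetical stand-in for toolkit.train_tools.apply_snr_weight.
    # Assumes a diffusers-style scheduler exposing alphas_cumprod and a per-sample loss of shape (batch,).
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)    # SNR(t) = abar_t / (1 - abar_t)
    weight = torch.clamp(snr, max=gamma) / snr       # min(SNR, gamma) / SNR
    return loss * weight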
25
extensions_built_in/textual_inversion_trainer/__init__.py
Normal file
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class TextualInversionTrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "textual_inversion_trainer"

    # name is the name of the extension for printing
    name = "Textual Inversion Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .TextualInversionTrainer import TextualInversionTrainer
        return TextualInversionTrainer


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    TextualInversionTrainerExtension
]
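The uid above is how a config's process type is matched to this extension: the example YAML in this commit sets type: 'textual_inversion_trainer'. The loader lives in toolkit.extension and is not shown in this diff, so the following is only a hedged sketch, under the assumption that every extension package exposes an AI_TOOLKIT_EXTENSIONS list like the one above; build_extension_index is a hypothetical name.

import importlib
import pkgutil

def build_extension_index(package_name="extensions_built_in"):
    # Hypothetical sketch: walk the extensions package and collect every
    # AI_TOOLKIT_EXTENSIONS entry, keyed by each Extension subclass's uid.
    index = {}
    package = importlib.import_module(package_name)
    for mod_info in pkgutil.iter_modules(package.__path__):
        module = importlib.import_module(f"{package_name}.{mod_info.name}")
        for ext_cls in getattr(module, "AI_TOOLKIT_EXTENSIONS", []):
            index[ext_cls.uid] = ext_cls
    return index

# usage sketch: resolve the process class for type 'textual_inversion_trainer'
# process_cls = build_extension_index()["textual_inversion_trainer"].get_process()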
@@ -0,0 +1,92 @@
---
job: extension
config:
  name: test_v1
  process:
    - type: 'textual_inversion_trainer'
      training_folder: "out/TI"
      device: cuda:0
      # for tensorboard logging
      log_dir: "out/.tensorboard"
      embedding:
        trigger: "your_trigger_here"
        tokens: 12
        init_words: "man with short brown hair"
        save_format: "safetensors"  # 'safetensors' or 'pt'
      save:
        dtype: float16  # precision to save
        save_every: 100  # save every this many steps
        max_step_saves_to_keep: 5  # only affects step counts
      datasets:
        - folder_path: "/path/to/dataset"
          caption_type: "txt"
          default_caption: "[trigger]"
          buckets: true
          resolution: 512
      train:
        noise_scheduler: "ddpm"  # or "lms", "euler_a"
        steps: 3000
        weight_jitter: 0.0
        lr: 5e-5
        train_unet: false
        gradient_checkpointing: true
        train_text_encoder: false
        optimizer: "adamw"
        # optimizer: "prodigy"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 4
        dtype: bf16
        xformers: true
        min_snr_gamma: 5.0
        # skip_first_sample: true
        noise_offset: 0.0  # not needed for this
      model:
        # objective reality v2
        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
        is_v2: false  # for v2 models
        is_xl: false  # for SDXL models
        is_v_pred: false  # for v-prediction models (most v2 models)
      sample:
        sampler: "ddpm"  # must match train.noise_scheduler
        sample_every: 100  # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of [trigger] laughing"
          - "photo of [trigger] smiling"
          - "[trigger] close up"
          - "dark scene [trigger] frozen"
          - "[trigger] nighttime"
          - "a painting of [trigger]"
          - "a drawing of [trigger]"
          - "a cartoon of [trigger]"
          - "[trigger] pixar style"
          - "[trigger] costume"
        neg: ""
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0

      logging:
        log_every: 10  # log every this many steps
        use_wandb: false  # not supported yet
        verbose: false

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically.
meta:
  # [name] gets replaced with the name above
  name: "[name]"
  # version: '1.0'
  # creator:
  #   name: Your Name
  #   email: your@gmail.com
  #   website: https://your.website
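The [trigger] placeholder in default_caption and in the sample prompts is not literal text: during training, Embedding.inject_embedding_to_prompt (added later in this diff) replaces it with the trigger word and, with expand_token=True, with the full multi-token string. A tiny worked sketch of that expansion, using hypothetical values and fewer tokens than the config above:

trigger = "your_trigger_here"
tokens = 3  # the config uses 12; 3 keeps the example short
embedding_tokens = " ".join([trigger] + [f"{trigger}_{i}" for i in range(1, tokens)])

prompt = "photo of [trigger] laughing"
expanded = prompt.replace("[trigger]", embedding_tokens)
# -> "photo of your_trigger_here your_trigger_here_1 your_trigger_here_2 laughing"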
jobs/process/BaseSDTrainProcess.py
@@ -5,6 +5,8 @@ from typing import Union
from torch.utils.data import DataLoader

from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer

@@ -20,7 +22,7 @@ import torch
from tqdm import tqdm

from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig


def flush():
@@ -30,6 +32,7 @@

class BaseSDTrainProcess(BaseTrainProcess):
    sd: StableDiffusion
    embedding: Union[Embedding, None] = None

    def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
        super().__init__(process_id, job, config)
@@ -59,6 +62,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.lr_scheduler = None
        self.data_loader: Union[DataLoader, None] = None

        raw_datasets = self.get_conf('datasets', None)
        self.datasets = None
        if raw_datasets is not None and len(raw_datasets) > 0:
            self.datasets = [DatasetConfig(**d) for d in raw_datasets]

        self.embed_config = None
        embedding_raw = self.get_conf('embedding', None)
        if embedding_raw is not None:
            self.embed_config = EmbeddingConfig(**embedding_raw)

        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
@@ -68,6 +81,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

        # to hold network if there is one
        self.network = None
        self.embedding = None

    def sample(self, step=None, is_first=False):
        sample_folder = os.path.join(self.save_root, 'samples')
@@ -89,8 +103,18 @@ class BaseSDTrainProcess(BaseTrainProcess):

            output_path = os.path.join(sample_folder, filename)

            prompt = sample_config.prompts[i]

            # add embedding if there is one
            # note: diffusers will automatically expand the trigger to the number of added tokens
            # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
            if self.embedding is not None:
                prompt = self.embedding.inject_embedding_to_prompt(
                    prompt,
                )

            gen_img_config_list.append(GenerateImageConfig(
-               prompt=sample_config.prompts[i],  # it will autoparse the prompt
+               prompt=prompt,  # it will autoparse the prompt
                width=sample_config.width,
                height=sample_config.height,
                negative_prompt=sample_config.neg,
@@ -175,6 +199,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                metadata=save_meta
            )
            self.network.multiplier = prev_multiplier
        elif self.embedding is not None:
            # set current step
            self.embedding.step = self.step_num
            # change filename to pt if that is set
            if self.embed_config.save_format == "pt":
                # replace extension
                file_path = os.path.splitext(file_path)[0] + ".pt"
            self.embedding.save(file_path)
        else:
            self.sd.save(
                file_path,
@@ -197,7 +229,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def hook_before_train_loop(self):
        pass

-   def hook_train_loop(self, batch=None):
+   def before_dataset_load(self):
+       pass
+
+   def hook_train_loop(self, batch):
        # return loss
        return 0.0

@@ -208,6 +243,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
        pattern = f"{self.job.name}*.safetensors"
        files = glob.glob(os.path.join(self.save_root, pattern))
        if len(files) > 0:
            latest_file = max(files, key=os.path.getctime)
        # try pt
        pattern = f"{self.job.name}*.pt"
        files = glob.glob(os.path.join(self.save_root, pattern))
        if len(files) > 0:
            latest_file = max(files, key=os.path.getctime)
        return latest_file
@@ -230,11 +270,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def run(self):
        # run base process run
        BaseTrainProcess.run(self)
        ### HOOK ###
        self.before_dataset_load()
        # load datasets if passed in the root process
        if self.datasets is not None:
            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)

        ### HOOK ###
        self.hook_before_model_load()
        # run base sd process run
        self.sd.load_model()

        if self.train_config.gradient_checkpointing:
            # may get disabled elsewhere
            self.sd.unet.enable_gradient_checkpointing()

        dtype = get_torch_dtype(self.train_config.dtype)

        # model is loaded from BaseSDProcess
@@ -303,7 +353,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.print(f"Loading from {latest_save_path}")
            self.load_weights(latest_save_path)
            self.network.multiplier = 1.0
        elif self.embed_config is not None:
            self.embedding = Embedding(
                sd=self.sd,
                embed_config=self.embed_config
            )
            latest_save_path = self.get_latest_save_path()
            # load last saved weights
            if latest_save_path is not None:
                self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)

                # resume state from embedding
                self.step_num = self.embedding.step

            # set trainable params
            params = self.embedding.get_trainable_params()

        else:
            params = []
toolkit/config_modules.py
@@ -1,6 +1,6 @@
import os
import time
-from typing import List, Optional
+from typing import List, Optional, Literal
import random

@@ -50,6 +50,14 @@ class NetworkConfig:
        self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)


class EmbeddingConfig:
    def __init__(self, **kwargs):
        self.trigger = kwargs.get('trigger', 'custom_embedding')
        self.tokens = kwargs.get('tokens', 4)
        self.init_words = kwargs.get('init_words', '*')
        self.save_format = kwargs.get('save_format', 'safetensors')


class TrainConfig:
    def __init__(self, **kwargs):
        self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
@@ -68,6 +76,7 @@ class TrainConfig:
        self.optimizer_params = kwargs.get('optimizer_params', {})
        self.skip_first_sample = kwargs.get('skip_first_sample', False)
        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
        self.weight_jitter = kwargs.get('weight_jitter', 0.0)


class ModelConfig:
@@ -146,6 +155,21 @@ class SliderConfig:
            self.targets.append(target)


class DatasetConfig:
    caption_type: Literal["txt", "caption"] = 'txt'

    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'image')  # sd, slider, reference
        self.folder_path: str = kwargs.get('folder_path', None)
        self.default_caption: str = kwargs.get('default_caption', None)
        self.caption_type: str = kwargs.get('caption_type', None)
        self.random_scale: bool = kwargs.get('random_scale', False)
        self.random_crop: bool = kwargs.get('random_crop', False)
        self.resolution: int = kwargs.get('resolution', 512)
        self.scale: float = kwargs.get('scale', 1.0)
        self.buckets: bool = kwargs.get('buckets', False)


class GenerateImageConfig:
    def __init__(
            self,
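EmbeddingConfig and DatasetConfig are plain kwargs containers; BaseSDTrainProcess builds them straight from the YAML dicts (DatasetConfig(**d) per dataset entry, EmbeddingConfig(**embedding_raw)). A short usage sketch with values mirroring the example config above; unspecified keys fall back to the defaults in __init__:

from toolkit.config_modules import DatasetConfig, EmbeddingConfig

embed_config = EmbeddingConfig(trigger="your_trigger_here", tokens=12,
                               init_words="man with short brown hair",
                               save_format="safetensors")
dataset_config = DatasetConfig(folder_path="/path/to/dataset", caption_type="txt",
                               default_caption="[trigger]", buckets=True, resolution=512)
print(embed_config.tokens, dataset_config.resolution)  # 12 512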
toolkit/data_loader.py
@@ -1,23 +1,40 @@
import os
import random
from typing import List

import cv2
import numpy as np
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A

from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin

-class ImageDataset(Dataset):
BUCKET_STEPS = 64


def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
    # make sure resolution is divisible by 8
    if resolution % 8 != 0:
        resolution = resolution - (resolution % 8)


+class ImageDataset(Dataset, CaptionMixin):
    def __init__(self, config):
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        self.include_prompt = self.get_config('include_prompt', False)
        self.default_prompt = self.get_config('default_prompt', '')
        if self.include_prompt:
            self.caption_type = self.get_config('caption_type', 'txt')
        else:
            self.caption_type = None
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)

@@ -81,7 +98,11 @@ class ImageDataset(Dataset):

        img = self.transform(img)

-       return img
+       if self.include_prompt:
+           prompt = self.get_caption_item(index)
+           return img, prompt
+       else:
+           return img


class Augments:
@@ -268,3 +289,102 @@ class PairedImageDataset(Dataset):
        img = self.transform(img)

        return img, prompt, (self.neg_weight, self.pos_weight)


class AiToolkitDataset(Dataset, CaptionMixin):
    def __init__(self, dataset_config: 'DatasetConfig'):
        self.dataset_config = dataset_config
        self.folder_path = dataset_config.folder_path
        self.caption_type = dataset_config.caption_type
        self.default_caption = dataset_config.default_caption
        self.random_scale = dataset_config.random_scale
        self.scale = dataset_config.scale
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
        self.resolution = dataset_config.resolution

        # get the file list
        self.file_list = [
            os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
            file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
        ]

        # this might take a while
        print(f" - Preprocessing image dimensions")
        new_file_list = []
        bad_count = 0
        for file in tqdm(self.file_list):
            img = Image.open(file)
            if int(min(img.size) * self.scale) >= self.resolution:
                new_file_list.append(file)
            else:
                bad_count += 1

        print(f" - Found {len(self.file_list)} images")
        print(f" - Found {bad_count} images that are too small")
        # keep only the images that are large enough for the target resolution
        self.file_list = new_file_list
        assert len(self.file_list) > 0, f"no images found in {self.folder_path}"

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = exif_transpose(Image.open(img_path)).convert('RGB')

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                if min_img_size < self.resolution:
                    print(
                        f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
                    scale_size = self.resolution
                else:
                    scale_size = random.randint(self.resolution, int(min_img_size))
                img = img.resize((scale_size, scale_size), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        if self.caption_type is not None:
            prompt = self.get_caption_item(index)
            return img, prompt
        else:
            return img


def get_dataloader_from_datasets(dataset_options, batch_size=1):
    # TODO do bucketing
    if dataset_options is None or len(dataset_options) == 0:
        return None

    datasets = []
    for dataset_option in dataset_options:
        if isinstance(dataset_option, DatasetConfig):
            config = dataset_option
        else:
            config = DatasetConfig(**dataset_option)
        if config.type == 'image':
            dataset = AiToolkitDataset(config)
            datasets.append(dataset)
        else:
            raise ValueError(f"invalid dataset type: {config.type}")

    concatenated_dataset = ConcatDataset(datasets)
    data_loader = DataLoader(
        concatenated_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )
    return data_loader
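get_dataloader_from_datasets accepts DatasetConfig objects or raw dicts, builds one AiToolkitDataset per entry, and concatenates them. A hedged usage sketch (the folder path is a placeholder); with the default collate function each batch is an image tensor plus a batch of caption strings, which is exactly what the trainer's hook_train_loop unpacks as imgs, prompts:

from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import get_dataloader_from_datasets

configs = [DatasetConfig(folder_path="/path/to/dataset", caption_type="txt",
                         default_caption="[trigger]", resolution=512)]
data_loader = get_dataloader_from_datasets(configs, batch_size=4)

for imgs, prompts in data_loader:
    # imgs: float tensor of shape (4, 3, 512, 512) in [-1, 1]; prompts: 4 caption strings
    break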
43
toolkit/dataloader_mixins.py
Normal file
@@ -0,0 +1,43 @@
import os


class CaptionMixin:
    def get_caption_item(self, index):
        if not hasattr(self, 'caption_type'):
            raise Exception('caption_type not found on class instance')
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            # check if either has a prompt file
            path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
            prompt_path = path_no_ext + '.txt'
            if not os.path.exists(prompt_path):
                path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
                prompt_path = path_no_ext + '.txt'
        else:
            img_path = img_path_or_tuple
            # see if prompt file exists
            path_no_ext = os.path.splitext(img_path)[0]
            prompt_path = path_no_ext + '.txt'

        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
                # remove any newlines
                prompt = prompt.replace('\n', ', ')
                # remove new lines for all operating systems
                prompt = prompt.replace('\r', ', ')
                prompt_split = prompt.split(',')
                # remove empty strings
                prompt_split = [p.strip() for p in prompt_split if p.strip()]
                # join back together
                prompt = ', '.join(prompt_split)
        else:
            prompt = ''
            # get default_prompt if it exists on the class instance
            if hasattr(self, 'default_prompt'):
                prompt = self.default_prompt
            if hasattr(self, 'default_caption'):
                prompt = self.default_caption
        return prompt
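A quick usage sketch of the caption lookup above, with hypothetical paths: for /data/img001.jpg the mixin looks for /data/img001.txt, collapses newlines to ", ", and falls back to default_caption (or default_prompt) when no sidecar file exists.

from toolkit.dataloader_mixins import CaptionMixin

class TinyDataset(CaptionMixin):
    def __init__(self):
        self.file_list = ["/data/img001.jpg"]   # hypothetical path
        self.caption_type = "txt"
        self.default_caption = "[trigger]"

ds = TinyDataset()
# returns the cleaned contents of /data/img001.txt if it exists, otherwise "[trigger]"
print(ds.get_caption_item(0))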
220
toolkit/embedding.py
Normal file
@@ -0,0 +1,220 @@
import json
import os
from collections import OrderedDict

import safetensors
import torch
from typing import TYPE_CHECKING

from safetensors.torch import save_file

from toolkit.metadata import get_meta_for_safetensors

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import EmbeddingConfig


# this is a frankenstein mix of automatic1111 and my own code

class Embedding:
    def __init__(
            self,
            sd: 'StableDiffusion',
            embed_config: 'EmbeddingConfig'
    ):
        self.name = embed_config.trigger
        self.sd = sd
        self.trigger = embed_config.trigger
        self.embed_config = embed_config
        self.step = 0
        # setup our embedding
        # Add the placeholder token in tokenizer
        placeholder_tokens = [self.embed_config.trigger]

        # add dummy tokens for multi-vector
        additional_tokens = []
        for i in range(1, self.embed_config.tokens):
            additional_tokens.append(f"{self.embed_config.trigger}_{i}")
        placeholder_tokens += additional_tokens

        num_added_tokens = self.sd.tokenizer.add_tokens(placeholder_tokens)
        if num_added_tokens != self.embed_config.tokens:
            raise ValueError(
                f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        # Convert the initializer_token, placeholder_token to ids
        init_token_ids = self.sd.tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
        # if there are more init token ids than embedding tokens, truncate; if fewer, pad with "*"
        if len(init_token_ids) > self.embed_config.tokens:
            init_token_ids = init_token_ids[:self.embed_config.tokens]
        elif len(init_token_ids) < self.embed_config.tokens:
            pad_token_id = self.sd.tokenizer.encode(["*"], add_special_tokens=False)
            init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))

        self.placeholder_token_ids = self.sd.tokenizer.convert_tokens_to_ids(placeholder_tokens)

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        # todo SDXL has 2 text encoders, need to do both for all of this
        self.sd.text_encoder.resize_token_embeddings(len(self.sd.tokenizer))

        # Initialise the newly added placeholder token with the embeddings of the initializer token
        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for initializer_token_id, token_id in zip(init_token_ids, self.placeholder_token_ids):
                token_embeds[token_id] = token_embeds[initializer_token_id].clone()

        # replace "[name]" with this during training. It is generated automatically by the pipeline at inference
        self.embedding_tokens = " ".join(self.sd.tokenizer.convert_ids_to_tokens(self.placeholder_token_ids))

    # returns the string to have in the prompt to trigger the embedding
    def get_embedding_string(self):
        return self.embedding_tokens

    def get_trainable_params(self):
        # todo only get this one as we could have more than one
        return self.sd.text_encoder.get_input_embeddings().parameters()

    # make setter and getter for vec
    @property
    def vec(self):
        # should we get params instead
        # create vector from token embeds
        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
        # stack the tokens along batch axis adding that axis
        new_vector = torch.stack(
            [token_embeds[token_id] for token_id in self.placeholder_token_ids],
            dim=0
        )
        return new_vector

    @vec.setter
    def vec(self, new_vector):
        # shape is (1, 768) for SD 1.5 for 1 token
        token_embeds = self.sd.text_encoder.get_input_embeddings().weight.data
        for i in range(new_vector.shape[0]):
            # apply the weights to the placeholder tokens while preserving gradient
            token_embeds[self.placeholder_token_ids[i]] = new_vector[i].clone()

    # diffusers automatically expands the token meaning test123 becomes test123 test123_1 test123_2 etc
    # however, on training we don't use that pipeline, so we have to do it ourselves
    def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
        output_prompt = prompt
        default_replacements = [self.name, self.trigger, "[name]", "[trigger]", self.embedding_tokens]

        replace_with = self.embedding_tokens if expand_token else self.trigger
        if to_replace_list is None:
            to_replace_list = default_replacements
        else:
            to_replace_list += default_replacements

        # remove duplicates
        to_replace_list = list(set(to_replace_list))

        # replace them all
        for to_replace in to_replace_list:
            # replace it
            output_prompt = output_prompt.replace(to_replace, replace_with)

        # see how many times replace_with is in the prompt
        num_instances = prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            # add it to the beginning of the prompt
            output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {self.name} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt

    def save(self, filename):
        # todo check to see how to get the vector out of the embedding

        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            # todo get these
            "sd_checkpoint": None,
            "sd_checkpoint_name": None,
            "notes": None,
        }
        if filename.endswith('.pt'):
            torch.save(embedding_data, filename)
        elif filename.endswith('.bin'):
            torch.save(embedding_data, filename)
        elif filename.endswith('.safetensors'):
            # save the embedding as a safetensors file
            state_dict = {"emb_params": self.vec}
            # add all embedding data (except string_to_param) to metadata
            metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
            metadata["string_to_param"] = {"*": "emb_params"}
            save_meta = get_meta_for_safetensors(metadata, name=self.name)
            save_file(state_dict, filename, metadata=save_meta)

    def load_embedding_from_file(self, file_path, device):
        # full path
        path = os.path.realpath(file_path)
        filename = os.path.basename(path)
        name, ext = os.path.splitext(filename)
        ext = ext.upper()
        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

        if ext in ['.BIN', '.PT']:
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            # rebuild the embedding from the safetensors file if it has it
            tensors = {}
            with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
            # data = safetensors.torch.load_file(path, device="cpu")
            if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
                # our format
                def try_json(v):
                    try:
                        return json.loads(v)
                    except:
                        return v

                data = {k: try_json(v) for k, v in metadata.items()}
                data['string_to_param'] = {'*': tensors['emb_params']}
            else:
                # old format
                data = tensors
        else:
            return

        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict,
                                     '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(
                f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")

        if 'step' in data:
            self.step = int(data['step'])

        self.vec = emb.detach().to(device, dtype=torch.float32)
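Embedding.save() writes the vector under the emb_params key and JSON-encodes the remaining fields into the safetensors metadata via get_meta_for_safetensors, which is not shown in this diff, so the exact metadata keys below are an assumption. A hedged sketch for inspecting such a file:

from safetensors.torch import safe_open

# path is hypothetical; this is how a saved TI embedding could be inspected
with safe_open("out/TI/test_v1.safetensors", framework="pt", device="cpu") as f:
    meta = f.metadata()                 # name, step, string_to_token, ... as strings
    vec = f.get_tensor("emb_params")    # shape (tokens, hidden_dim), e.g. (12, 768) for SD 1.5

print(meta.get("name"), meta.get("step"), tuple(vec.shape))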
toolkit/stable_diffusion_model.py
@@ -435,7 +435,7 @@ class StableDiffusion:
            text_embeddings = train_tools.concat_prompt_embeddings(
                unconditional_embeddings,  # negative embedding
                conditional_embeddings,  # positive embedding
-               latents.shape[0],  # batch size
+               1,  # batch size
            )
        elif text_embeddings is None and conditional_embeddings is not None:
            # not doing cfg
@@ -506,6 +506,17 @@

        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)

        # check if we need to concat timesteps
        if isinstance(timestep, torch.Tensor):
            ts_bs = timestep.shape[0]
            if ts_bs != latent_model_input.shape[0]:
                if ts_bs == 1:
                    timestep = torch.cat([timestep] * latent_model_input.shape[0])
                elif ts_bs * 2 == latent_model_input.shape[0]:
                    timestep = torch.cat([timestep] * 2)
                else:
                    raise ValueError(f"Batch size of timesteps {timestep.shape[0]} must equal or be half the batch size of latents {latent_model_input.shape[0]}")

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
@@ -520,6 +531,11 @@
                noise_pred_text - noise_pred_uncond
            )

            # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
            if guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

        return noise_pred

    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
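The timestep-broadcasting check added to predict_noise covers the two layouts this commit produces: training calls (no CFG, one timestep per latent) and CFG sampling, where the latents are duplicated into unconditional and conditional halves. A small worked illustration of the rule, shapes only:

import torch

latent_bs = 8                                   # e.g. 4 images duplicated for CFG
timestep = torch.tensor([10, 250, 400, 999])    # one timestep per image, ts_bs = 4

if timestep.shape[0] * 2 == latent_bs:
    timestep = torch.cat([timestep] * 2)        # -> shape (8,), matches the latent batch

print(timestep.shape)  # torch.Size([8])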