mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Upgraded to dev for t2i on diffusers. Minor migrations to make it work.
193 extensions_built_in/advanced_generator/ReferenceGenerator.py Normal file
@@ -0,0 +1,193 @@
import os
import random
from collections import OrderedDict
from typing import List

import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLAdapterPipeline
from tqdm import tqdm

from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.train_tools import get_torch_dtype
from controlnet_aux.midas import MidasDetector
from diffusers.utils import load_image


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class GenerateConfig:

    def __init__(self, **kwargs):
        self.prompts: List[str]
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.walk_seed = kwargs.get('walk_seed', False)
        self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.prompt_2 = kwargs.get('prompt_2', None)
        self.neg_2 = kwargs.get('neg_2', None)
        self.prompts = kwargs.get('prompts', None)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.ext = kwargs.get('ext', 'png')
        self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
        if kwargs.get('shuffle', False):
            # shuffle the prompts
            random.shuffle(self.prompts)


class ReferenceGenerator(BaseExtensionProcess):

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.is_latents_cached = True
        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.params = []
        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                if not is_caching:
                    self.is_latents_cached = False
                if dataset.is_reg:
                    if self.datasets_reg is None:
                        self.datasets_reg = []
                    self.datasets_reg.append(dataset)
                else:
                    if self.datasets is None:
                        self.datasets = []
                    self.datasets.append(dataset)

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )
        print(f"Using device {self.device}")
        self.data_loader: DataLoader = None
        self.adapter: T2IAdapter = None

    def run(self):
        super().run()
        print("Loading model...")
        self.sd.load_model()
        device = torch.device(self.device)

        if self.generate_config.t2i_adapter_path is not None:
            self.adapter = T2IAdapter.from_pretrained(
                "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, variant="fp16"
            ).to(device)

        midas_depth = MidasDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
        ).to(device)

        pipe = StableDiffusionXLAdapterPipeline(
            vae=self.sd.vae,
            unet=self.sd.unet,
            text_encoder=self.sd.text_encoder[0],
            text_encoder_2=self.sd.text_encoder[1],
            tokenizer=self.sd.tokenizer[0],
            tokenizer_2=self.sd.tokenizer[1],
            scheduler=get_sampler(self.generate_config.sampler),
            adapter=self.adapter,
        ).to(device)
        pipe.set_progress_bar_config(disable=True)

        self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)

        num_batches = len(self.data_loader)
        pbar = tqdm(total=num_batches, desc="Generating images")
        seed = self.generate_config.seed
        # load images from datasets, use tqdm
        for i, batch in enumerate(self.data_loader):
            batch: DataLoaderBatchDTO = batch

            file_item: FileItemDTO = batch.file_items[0]
            img_path = file_item.path
            img_filename = os.path.basename(img_path)
            img_filename_no_ext = os.path.splitext(img_filename)[0]
            output_path = os.path.join(self.output_folder, img_filename)
            output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
            output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')

            caption = batch.get_caption_list()[0]

            img: torch.Tensor = batch.tensor.clone()
            # image comes in -1 to 1. convert to a PIL RGB image
            img = (img + 1) / 2
            img = img.clamp(0, 1)
            img = img[0].permute(1, 2, 0).cpu().numpy()
            img = (img * 255).astype(np.uint8)
            image = Image.fromarray(img)

            width, height = image.size
            min_res = min(width, height)

            if self.generate_config.walk_seed:
                seed = seed + 1

            if self.generate_config.seed == -1:
                # random
                seed = random.randint(0, 1000000)

            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

            # generate depth map
            image = midas_depth(
                image,
                detect_resolution=min_res,  # do 512 ?
                image_resolution=min_res
            )

            # image.save(output_depth_path)

            gen_images = pipe(
                prompt=caption,
                negative_prompt=self.generate_config.neg,
                image=image,
                num_inference_steps=self.generate_config.sample_steps,
                adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
                guidance_scale=self.generate_config.guidance_scale,
            ).images[0]
            gen_images.save(output_path)

            # save caption
            with open(output_caption_path, 'w') as f:
                f.write(caption)

            pbar.update(1)
            batch.cleanup()

        pbar.close()
        print("Done generating images")
        # cleanup
        del self.sd
        gc.collect()
        torch.cuda.empty_cache()
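
The process above is configured entirely through the get_conf keys read in __init__ and GenerateConfig. A hypothetical config sketch matching those keys follows; the paths and the model name are placeholders and not values from this commit, and the nested model/dataset keys are taken from the example YAML later in this commit.

# Hypothetical config sketch for the ReferenceGenerator process (placeholder paths/model).
example_config = {
    "output_folder": "/path/to/output",  # required
    "device": "cuda",
    "dtype": "float16",
    "model": {  # passed to ModelConfig(**...)
        "name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder
        "is_xl": True,
    },
    "generate": {  # passed to GenerateConfig(**...)
        "sampler": "ddpm",
        "seed": 42,
        "walk_seed": True,
        "t2i_adapter_path": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
        "guidance_scale": 7,
        "sample_steps": 20,
        "adapter_conditioning_scale": 1.0,
    },
    "datasets": [  # each entry is passed to DatasetConfig(**...)
        {"folder_path": "/path/to/dataset", "caption_ext": "txt", "resolution": 1024},
    ],
}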
25 extensions_built_in/advanced_generator/__init__.py Normal file
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# This is for generic training (LoRA, Dreambooth, FineTuning)
class AdvancedReferenceGeneratorExtension(Extension):
    # uid must be unique; it is how the extension is identified
    uid = "reference_generator"

    # name is the name of the extension for printing
    name = "Reference Generator"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ReferenceGenerator import ReferenceGenerator
        return ReferenceGenerator


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    AdvancedReferenceGeneratorExtension,
]
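
A quick illustration of the lazy-loading pattern above: nothing from ReferenceGenerator.py is imported until get_process() is called. This is a usage sketch, assuming the package is importable as extensions_built_in.advanced_generator as in the file paths shown here.

# Usage sketch: the heavy imports in ReferenceGenerator.py only happen on this call.
from extensions_built_in.advanced_generator import AdvancedReferenceGeneratorExtension

process_cls = AdvancedReferenceGeneratorExtension.get_process()
print(AdvancedReferenceGeneratorExtension.uid, process_cls.__name__)  # reference_generator ReferenceGenerator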
@@ -0,0 +1,91 @@
---
job: extension
config:
  name: test_v1
  process:
    - type: 'textual_inversion_trainer'
      training_folder: "out/TI"
      device: cuda:0
      # for tensorboard logging
      log_dir: "out/.tensorboard"
      embedding:
        trigger: "your_trigger_here"
        tokens: 12
        init_words: "man with short brown hair"
        save_format: "safetensors" # 'safetensors' or 'pt'
      save:
        dtype: float16 # precision to save
        save_every: 100 # save every this many steps
        max_step_saves_to_keep: 5 # only affects step counts
      datasets:
        - folder_path: "/path/to/dataset"
          caption_ext: "txt"
          default_caption: "[trigger]"
          buckets: true
          resolution: 512
      train:
        noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
        steps: 3000
        weight_jitter: 0.0
        lr: 5e-5
        train_unet: false
        gradient_checkpointing: true
        train_text_encoder: false
        optimizer: "adamw"
        # optimizer: "prodigy"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 4
        dtype: bf16
        xformers: true
        min_snr_gamma: 5.0
        # skip_first_sample: true
        noise_offset: 0.0 # not needed for this
      model:
        # objective reality v2
        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of [trigger] laughing"
          - "photo of [trigger] smiling"
          - "[trigger] close up"
          - "dark scene [trigger] frozen"
          - "[trigger] nighttime"
          - "a painting of [trigger]"
          - "a drawing of [trigger]"
          - "a cartoon of [trigger]"
          - "[trigger] pixar style"
          - "[trigger] costume"
        neg: ""
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0

      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model so be aware of that. The software will include this
# plus some other information for you automatically
meta:
  # [name] gets replaced with the name above
  name: "[name]"
  # version: '1.0'
  # creator:
  #   name: Your Name
  #   email: your@gmail.com
  #   website: https://your.website
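
The [trigger] and [name] placeholders in the config above follow a simple string-substitution convention ("[name] gets replaced with the name above"). A minimal illustration of the idea, not the toolkit's own code:

# Illustrative only: how a "[trigger]" placeholder in a prompt or caption resolves.
trigger = "your_trigger_here"
prompt = "photo of [trigger] laughing".replace("[trigger]", trigger)
print(prompt)  # photo of your_trigger_here laughing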
@@ -1,7 +1,7 @@
torch
torchvision
safetensors
diffusers
git+https://github.com/huggingface/diffusers.git
transformers
lycoris_lora
flatten_json
@@ -19,4 +19,5 @@ omegaconf
k-diffusion
open_clip_torch
timm
prodigyopt
prodigyopt
controlnet_aux==0.0.7
@@ -77,7 +77,7 @@ class TrainConfig:
        self.optimizer_params = kwargs.get('optimizer_params', {})
        self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
        self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
        self.batch_size: int = kwargs.get('batch_size', 1)
        self.dtype: str = kwargs.get('dtype', 'fp32')
        self.xformers = kwargs.get('xformers', False)
@@ -12,6 +12,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A

from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
@@ -268,6 +269,37 @@ class PairedImageDataset(Dataset):
            img1 = exif_transpose(Image.open(img_path)).convert('RGB')
            img_path = img_path_or_tuple[1]
            img2 = exif_transpose(Image.open(img_path)).convert('RGB')

            # always use # 2 (pos)
            bucket_resolution = get_bucket_for_image_size(
                width=img2.width,
                height=img2.height,
                resolution=self.size
            )

            # images will be same base dimension, but may be trimmed. We need to shrink and then central crop
            if bucket_resolution['width'] > bucket_resolution['height']:
                img1_scale_to_height = bucket_resolution["height"]
                img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
                img2_scale_to_height = bucket_resolution["height"]
                img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
            else:
                img1_scale_to_width = bucket_resolution["width"]
                img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
                img2_scale_to_width = bucket_resolution["width"]
                img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))

            img1_crop_height = bucket_resolution["height"]
            img1_crop_width = bucket_resolution["width"]
            img2_crop_height = bucket_resolution["height"]
            img2_crop_width = bucket_resolution["width"]

            # scale then center crop images
            img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
            img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
            img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
            img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)

            # combine them side by side
            img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
            img.paste(img1, (0, 0))
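
To make the shrink-then-center-crop arithmetic above concrete, a small worked example with made-up numbers (not values from the repo):

# Illustrative numbers only: landscape bucket, so images are scaled to the bucket height.
bucket = {"width": 768, "height": 512}
img_w, img_h = 2000, 1200
scale_to_height = bucket["height"]                        # 512
scale_to_width = int(img_w * (bucket["height"] / img_h))  # int(2000 * 0.4267) = 853
# resize to 853x512, then CenterCrop((512, 768)) trims about 85 px total from the sides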
@@ -275,15 +307,14 @@ class PairedImageDataset(Dataset):
        else:
            img_path = img_path_or_tuple
            img = exif_transpose(Image.open(img_path)).convert('RGB')
            height = self.size
            # determine width to keep aspect ratio
            width = int(img.size[0] * height / img.size[1])

            # Downscale the source image first
            img = img.resize((width, height), Image.BICUBIC)

        prompt = self.get_prompt_item(index)

        height = self.size
        # determine width to keep aspect ratio
        width = int(img.size[0] * height / img.size[1])

        # Downscale the source image first
        img = img.resize((width, height), Image.BICUBIC)
        img = self.transform(img)

        return img, prompt, (self.neg_weight, self.pos_weight)
@@ -122,11 +122,14 @@ class ToolkitModuleMixin:
        return lx * scale

    def forward(self: Module, x):
    # this may get an additional positional arg or not
    def forward(self: Module, x, *args, **kwargs):
        # diffusers added scale to resnet.. not sure what it does
        if self._multiplier is None:
            self.set_multiplier(0.0)

        org_forwarded = self.org_forward(x)
        org_forwarded = self.org_forward(x, *args, **kwargs)
        lora_output = self._call_forward(x)
        multiplier = self._multiplier.clone().detach()
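
This last hunk is the "minor migration" from the commit message: newer diffusers builds can pass an extra argument (for example a resnet scale) into patched module forwards, so the wrapper now forwards *args/**kwargs through to the original module. A standalone sketch of the pattern, not the repo's class:

import torch.nn as nn

class PassthroughWrapper(nn.Module):
    # Keeps a reference to the wrapped module's forward and tolerates extra args
    # such as the scale that newer diffusers may pass to patched layers.
    def __init__(self, module: nn.Module):
        super().__init__()
        self.org_forward = module.forward

    def forward(self, x, *args, **kwargs):
        # pass any additional positional/keyword args straight through
        return self.org_forward(x, *args, **kwargs)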