Upgraded to the diffusers dev build for T2I-Adapter (t2i) support. Minor migrations to make the rest of the toolkit work with it.

This commit is contained in:
Jaret Burkett
2023-09-11 14:46:06 -06:00
parent 083cefa78c
commit e8583860ad
7 changed files with 356 additions and 12 deletions

View File

@@ -0,0 +1,193 @@
import os
import random
from collections import OrderedDict
from typing import List
import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLAdapterPipeline
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.sampler import get_sampler
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.train_tools import get_torch_dtype
from controlnet_aux.midas import MidasDetector
from diffusers.utils import load_image
def flush():
    torch.cuda.empty_cache()
    gc.collect()
class GenerateConfig:

    def __init__(self, **kwargs):
        self.prompts: List[str] = kwargs.get('prompts', None)
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.neg = kwargs.get('neg', '')
        self.seed = kwargs.get('seed', -1)
        self.walk_seed = kwargs.get('walk_seed', False)
        self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
        self.guidance_scale = kwargs.get('guidance_scale', 7)
        self.sample_steps = kwargs.get('sample_steps', 20)
        self.prompt_2 = kwargs.get('prompt_2', None)
        self.neg_2 = kwargs.get('neg_2', None)
        self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
        self.ext = kwargs.get('ext', 'png')
        self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
        if kwargs.get('shuffle', False) and self.prompts is not None:
            # shuffle the prompts
            random.shuffle(self.prompts)
class ReferenceGenerator(BaseExtensionProcess):

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
        self.is_latents_cached = True
        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.params = []
        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
                is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                if not is_caching:
                    self.is_latents_cached = False
                if dataset.is_reg:
                    if self.datasets_reg is None:
                        self.datasets_reg = []
                    self.datasets_reg.append(dataset)
                else:
                    if self.datasets is None:
                        self.datasets = []
                    self.datasets.append(dataset)
        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )
        print(f"Using device {self.device}")
        self.data_loader: DataLoader = None
        self.adapter: T2IAdapter = None
    def run(self):
        super().run()
        print("Loading model...")
        self.sd.load_model()
        device = torch.device(self.device)

        if self.generate_config.t2i_adapter_path is not None:
            self.adapter = T2IAdapter.from_pretrained(
                "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, variant="fp16"
            ).to(device)

        midas_depth = MidasDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
        ).to(device)

        pipe = StableDiffusionXLAdapterPipeline(
            vae=self.sd.vae,
            unet=self.sd.unet,
            text_encoder=self.sd.text_encoder[0],
            text_encoder_2=self.sd.text_encoder[1],
            tokenizer=self.sd.tokenizer[0],
            tokenizer_2=self.sd.tokenizer[1],
            scheduler=get_sampler(self.generate_config.sampler),
            adapter=self.adapter,
        ).to(device)
        pipe.set_progress_bar_config(disable=True)

        self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
        num_batches = len(self.data_loader)
        pbar = tqdm(total=num_batches, desc="Generating images")
        seed = self.generate_config.seed

        # load images from datasets, use tqdm
        for i, batch in enumerate(self.data_loader):
            batch: DataLoaderBatchDTO = batch

            file_item: FileItemDTO = batch.file_items[0]
            img_path = file_item.path
            img_filename = os.path.basename(img_path)
            img_filename_no_ext = os.path.splitext(img_filename)[0]
            output_path = os.path.join(self.output_folder, img_filename)
            output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
            output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')

            caption = batch.get_caption_list()[0]

            img: torch.Tensor = batch.tensor.clone()
            # image comes in -1 to 1. convert to a PIL RGB image
            img = (img + 1) / 2
            img = img.clamp(0, 1)
            img = img[0].permute(1, 2, 0).cpu().numpy()
            img = (img * 255).astype(np.uint8)
            image = Image.fromarray(img)

            width, height = image.size
            min_res = min(width, height)

            if self.generate_config.walk_seed:
                seed = seed + 1
            if self.generate_config.seed == -1:
                # random
                seed = random.randint(0, 1000000)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

            # generate depth map
            image = midas_depth(
                image,
                detect_resolution=min_res,  # do 512 ?
                image_resolution=min_res
            )
            # image.save(output_depth_path)

            gen_images = pipe(
                prompt=caption,
                negative_prompt=self.generate_config.neg,
                image=image,
                num_inference_steps=self.generate_config.sample_steps,
                adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
                guidance_scale=self.generate_config.guidance_scale,
            ).images[0]
            gen_images.save(output_path)

            # save caption
            with open(output_caption_path, 'w') as f:
                f.write(caption)

            pbar.update(1)
            batch.cleanup()

        pbar.close()
        print("Done generating images")

        # cleanup
        del self.sd
        gc.collect()
        torch.cuda.empty_cache()
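
For reference, the upstream pipeline this file builds from the already-loaded StableDiffusion parts can also be constructed straight from the hub. A minimal sketch following the diffusers T2I-Adapter SDXL examples (model IDs, dtypes, and the depth_image input here are illustrative, not this repo's wiring):

import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
    torch_dtype=torch.float16, variant="fp16",
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter, torch_dtype=torch.float16, variant="fp16",
).to("cuda")
# depth_image would come from a MidasDetector pass, as in run() above
image = pipe(
    prompt="photo of a person",
    image=depth_image,
    num_inference_steps=20,
    adapter_conditioning_scale=1.0,
).images[0]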

View File

@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# This is for generic training (LoRA, Dreambooth, FineTuning)
class AdvancedReferenceGeneratorExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "reference_generator"

    # name is the name of the extension for printing
    name = "Reference Generator"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ReferenceGenerator import ReferenceGenerator
        return ReferenceGenerator


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    AdvancedReferenceGeneratorExtension,
]
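
The toolkit discovers extensions through each module's AI_TOOLKIT_EXTENSIONS list; a rough sketch of how such a registry can be consumed (the loader name here is hypothetical, not the toolkit's actual code):

import importlib

def collect_extension_processes(module_names):  # hypothetical helper
    processes = {}
    for module_name in module_names:
        module = importlib.import_module(module_name)
        for ext in getattr(module, "AI_TOOLKIT_EXTENSIONS", []):
            # get_process() defers heavy imports until the extension is actually used
            processes[ext.uid] = ext.get_process()
    return processes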

View File

@@ -0,0 +1,91 @@
---
job: extension
config:
  name: test_v1
  process:
    - type: 'textual_inversion_trainer'
      training_folder: "out/TI"
      device: cuda:0
      # for tensorboard logging
      log_dir: "out/.tensorboard"
      embedding:
        trigger: "your_trigger_here"
        tokens: 12
        init_words: "man with short brown hair"
        save_format: "safetensors" # 'safetensors' or 'pt'
      save:
        dtype: float16 # precision to save
        save_every: 100 # save every this many steps
        max_step_saves_to_keep: 5 # how many step saves to keep; older ones are pruned
      datasets:
        - folder_path: "/path/to/dataset"
          caption_ext: "txt"
          default_caption: "[trigger]"
          buckets: true
          resolution: 512
      train:
        noise_scheduler: "ddpm" # or "lms", "euler_a"
        steps: 3000
        weight_jitter: 0.0
        lr: 5e-5
        train_unet: false
        gradient_checkpointing: true
        train_text_encoder: false
        optimizer: "adamw"
        # optimizer: "prodigy"
        optimizer_params:
          weight_decay: 1e-2
        lr_scheduler: "constant"
        max_denoising_steps: 1000
        batch_size: 4
        dtype: bf16
        xformers: true
        min_snr_gamma: 5.0
        # skip_first_sample: true
        noise_offset: 0.0 # not needed for this
      model:
        # objective reality v2
        name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
        is_v2: false # for v2 models
        is_xl: false # for SDXL models
        is_v_pred: false # for v-prediction models (most v2 models)
      sample:
        sampler: "ddpm" # must match train.noise_scheduler
        sample_every: 100 # sample every this many steps
        width: 512
        height: 512
        prompts:
          - "photo of [trigger] laughing"
          - "photo of [trigger] smiling"
          - "[trigger] close up"
          - "dark scene [trigger] frozen"
          - "[trigger] nighttime"
          - "a painting of [trigger]"
          - "a drawing of [trigger]"
          - "a cartoon of [trigger]"
          - "[trigger] pixar style"
          - "[trigger] costume"
        neg: ""
        seed: 42
        walk_seed: false
        guidance_scale: 7
        sample_steps: 20
        network_multiplier: 1.0
      logging:
        log_every: 10 # log every this many steps
        use_wandb: false # not supported yet
        verbose: false

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
# It is saved in the model, so be aware of that. The software will include this
# plus some other information for you automatically.
meta:
  # [name] gets replaced with the name above
  name: "[name]"
  # version: '1.0'
  # creator:
  #   name: Your Name
  #   email: your@gmail.com
  #   website: https://your.website
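
The [trigger] placeholders above are substituted before captions and sample prompts reach the trainer; roughly like this (a hypothetical helper, not the toolkit's implementation):

def apply_trigger(text: str, trigger: str) -> str:
    # hypothetical: "[trigger]" becomes the embedding trigger word
    return text.replace("[trigger]", trigger)

# apply_trigger("photo of [trigger] laughing", "your_trigger_here")
# -> "photo of your_trigger_here laughing"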

View File

@@ -1,7 +1,7 @@
torch
torchvision
safetensors
-diffusers
+git+https://github.com/huggingface/diffusers.git
transformers
lycoris_lora
flatten_json
@@ -19,4 +19,5 @@ omegaconf
k-diffusion
open_clip_torch
timm
-prodigyopt
+prodigyopt
+controlnet_aux==0.0.7

View File

@@ -77,7 +77,7 @@ class TrainConfig:
        self.optimizer_params = kwargs.get('optimizer_params', {})
        self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
        self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
-        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
+        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
        self.batch_size: int = kwargs.get('batch_size', 1)
        self.dtype: str = kwargs.get('dtype', 'fp32')
        self.xformers = kwargs.get('xformers', False)
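
The new default of 1000 matches the full DDPM training schedule. The value typically bounds random timestep sampling during training, along these lines (a sketch under that assumption, not this repo's exact trainer code):

import torch

batch_size = 4
max_denoising_steps = 1000  # the new default
# draw one training timestep per sample in the batch
timesteps = torch.randint(0, max_denoising_steps, (batch_size,), dtype=torch.long)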

View File

@@ -12,6 +12,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
+from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
@@ -268,6 +269,37 @@ class PairedImageDataset(Dataset):
            img1 = exif_transpose(Image.open(img_path)).convert('RGB')
            img_path = img_path_or_tuple[1]
            img2 = exif_transpose(Image.open(img_path)).convert('RGB')
+
+            # always use # 2 (pos)
+            bucket_resolution = get_bucket_for_image_size(
+                width=img2.width,
+                height=img2.height,
+                resolution=self.size
+            )
+
+            # images will be same base dimension, but may be trimmed. We need to shrink and then center crop
+            if bucket_resolution['width'] > bucket_resolution['height']:
+                img1_scale_to_height = bucket_resolution["height"]
+                img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
+                img2_scale_to_height = bucket_resolution["height"]
+                img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
+            else:
+                img1_scale_to_width = bucket_resolution["width"]
+                img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
+                img2_scale_to_width = bucket_resolution["width"]
+                img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
+
+            img1_crop_height = bucket_resolution["height"]
+            img1_crop_width = bucket_resolution["width"]
+            img2_crop_height = bucket_resolution["height"]
+            img2_crop_width = bucket_resolution["width"]
+
+            # scale then center crop images
+            img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
+            img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
+            img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
+            img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
+
            # combine them side by side
            img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
            img.paste(img1, (0, 0))
@@ -275,15 +307,14 @@ class PairedImageDataset(Dataset):
        else:
            img_path = img_path_or_tuple
            img = exif_transpose(Image.open(img_path)).convert('RGB')
-            height = self.size
-            # determine width to keep aspect ratio
-            width = int(img.size[0] * height / img.size[1])
-            # Downscale the source image first
-            img = img.resize((width, height), Image.BICUBIC)

        prompt = self.get_prompt_item(index)
+        height = self.size
+        # determine width to keep aspect ratio
+        width = int(img.size[0] * height / img.size[1])
+        # Downscale the source image first
+        img = img.resize((width, height), Image.BICUBIC)
        img = self.transform(img)
        return img, prompt, (self.neg_weight, self.pos_weight)
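
A worked example of the scale-then-center-crop math in the paired branch above (the numbers are illustrative):

# landscape bucket 768x512, source image 1600x900
bucket = {"width": 768, "height": 512}
scale_to_height = bucket["height"]                      # 512
scale_to_width = int(1600 * (bucket["height"] / 900))   # int(910.2) = 910
# resize to 910x512, then CenterCrop((512, 768)) trims ~71 px from each side,
# leaving both images at the shared bucket resolution before pasting side by side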

View File

@@ -122,11 +122,14 @@ class ToolkitModuleMixin:
        return lx * scale

-    def forward(self: Module, x):
+    # this may get an additional positional arg or not
+    def forward(self: Module, x, *args, **kwargs):
+        # diffusers added scale to resnet.. not sure what it does
        if self._multiplier is None:
            self.set_multiplier(0.0)
-        org_forwarded = self.org_forward(x)
+        org_forwarded = self.org_forward(x, *args, **kwargs)
        lora_output = self._call_forward(x)
        multiplier = self._multiplier.clone().detach()
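
The change keeps patched LoRA modules compatible with base layers whose forward now takes extra arguments (such as the scale diffusers began passing to resnets). The pass-through pattern in isolation, as a sketch with hypothetical names:

class PatchedModule:  # hypothetical minimal illustration, not the toolkit's class
    def forward(self, x, *args, **kwargs):
        # forward any extra args/kwargs the wrapped base layer now expects
        base_out = self.org_forward(x, *args, **kwargs)
        # the LoRA path only ever consumes the input tensor itself
        return base_out + self._call_forward(x) * self._multiplier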