Bug fixes, work on maing IP adapters more customizable.

This commit is contained in:
Jaret Burkett
2023-12-24 08:32:39 -07:00
parent 7703e3a15e
commit 0f8daa5612
7 changed files with 243 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ from typing import List, TYPE_CHECKING
import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
@@ -24,6 +25,45 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
class RescaleTransform:
"""Transform to rescale images to the range [-1, 1]."""
def __call__(self, image):
return image * 2 - 1
class NormalizeSDXLTransform:
"""
Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images
Mean: tensor([ 0.0002, -0.1034, -0.1879])
Standard Deviation: tensor([0.5436, 0.5116, 0.5033])
"""
def __call__(self, image):
return transforms.Normalize(
mean=[0.0002, -0.1034, -0.1879],
std=[0.5436, 0.5116, 0.5033],
)(image)
class NormalizeSD15Transform:
"""
Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images
Mean: tensor([-0.1600, -0.2450, -0.3227])
Standard Deviation: tensor([0.5319, 0.4997, 0.5139])
"""
def __call__(self, image):
return transforms.Normalize(
mean=[-0.1600, -0.2450, -0.3227],
std=[0.5319, 0.4997, 0.5139],
)(image)
class ImageDataset(Dataset, CaptionMixin):
def __init__(self, config):
self.config = config
@@ -63,7 +103,7 @@ class ImageDataset(Dataset, CaptionMixin):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
RescaleTransform(),
])
def get_config(self, key, default=None, required=False):
@@ -200,7 +240,7 @@ class PairedImageDataset(Dataset):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
RescaleTransform(),
])
def get_all_prompts(self):
@@ -368,6 +408,23 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# repeat the list
file_list = file_list * self.dataset_config.num_repeats
if self.dataset_config.standardize_images:
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
NormalizeMethod = NormalizeSDXLTransform
else:
NormalizeMethod = NormalizeSD15Transform
self.transform = transforms.Compose([
transforms.ToTensor(),
RescaleTransform(),
NormalizeMethod(),
])
else:
self.transform = transforms.Compose([
transforms.ToTensor(),
RescaleTransform(),
])
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0
@@ -375,7 +432,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
try:
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
dataset_config=dataset_config,
dataloader_transforms=self.transform,
)
self.file_list.append(file_item)
except Exception as e:
@@ -411,10 +469,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
if self.dataset_config.flip_x or self.dataset_config.flip_y:
print(f" - Found {len(self.file_list)} images after adding flips")
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
self.setup_epoch()