mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Bug fixes, work on making IP adapters more customizable.
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import List, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
@@ -24,6 +25,45 @@ if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class RescaleTransform:
    """Transform to rescale images to the range [-1, 1].

    The mapping applied is ``image * 2 - 1``, so an input spanning [0, 1]
    comes out spanning [-1, 1]. Any object supporting ``*`` and ``-`` with
    numbers works; the input itself is not modified.
    """

    def __call__(self, image):
        """Return ``image * 2 - 1``."""
        return image * 2 - 1

    def __repr__(self):
        # Gives the transform a readable name when a pipeline containing it
        # is printed, instead of the default ``<... object at 0x...>``.
        return f"{type(self).__name__}()"
|
||||
|
||||
|
||||
class NormalizeSDXLTransform:
    """Standardize an image tensor with per-channel SDXL mean and std.

    The statistics are averages measured over thousands of images:

    Mean: tensor([ 0.0002, -0.1034, -0.1879])
    Standard Deviation: tensor([0.5436, 0.5116, 0.5033])

    NOTE(review): this docstring used to describe the input range as
    "0 to 1", but AiToolkitDataset composes this transform after
    ``RescaleTransform``, so it is fed that transform's output rather than
    raw [0, 1] values. Confirm which range the statistics were measured on.
    """

    # Per-channel statistics, in the channel order of the incoming tensor.
    MEAN = (0.0002, -0.1034, -0.1879)
    STD = (0.5436, 0.5116, 0.5033)

    def __init__(self):
        # Build the normalizer once here rather than once per image inside
        # ``__call__``; the statistics never change between calls.
        self._normalize = transforms.Normalize(
            mean=list(self.MEAN),
            std=list(self.STD),
        )

    def __call__(self, image):
        """Return ``image`` normalized as ``(image - mean) / std`` per channel."""
        return self._normalize(image)

    def __repr__(self):
        # Readable name when a pipeline containing this transform is printed.
        return f"{type(self).__name__}(mean={list(self.MEAN)}, std={list(self.STD)})"
|
||||
|
||||
|
||||
class NormalizeSD15Transform:
    """Standardize an image tensor with per-channel SD 1.5 mean and std.

    The statistics are averages measured over thousands of images:

    Mean: tensor([-0.1600, -0.2450, -0.3227])
    Standard Deviation: tensor([0.5319, 0.4997, 0.5139])

    NOTE(review): this docstring used to describe the input range as
    "0 to 1" (and named SDXL by mistake), but AiToolkitDataset composes this
    transform after ``RescaleTransform``, so it is fed that transform's
    output rather than raw [0, 1] values. Confirm which range the
    statistics were measured on.
    """

    # Per-channel statistics, in the channel order of the incoming tensor.
    MEAN = (-0.1600, -0.2450, -0.3227)
    STD = (0.5319, 0.4997, 0.5139)

    def __init__(self):
        # Build the normalizer once here rather than once per image inside
        # ``__call__``; the statistics never change between calls.
        self._normalize = transforms.Normalize(
            mean=list(self.MEAN),
            std=list(self.STD),
        )

    def __call__(self, image):
        """Return ``image`` normalized as ``(image - mean) / std`` per channel."""
        return self._normalize(image)

    def __repr__(self):
        # Readable name when a pipeline containing this transform is printed.
        return f"{type(self).__name__}(mean={list(self.MEAN)}, std={list(self.STD)})"
|
||||
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
@@ -63,7 +103,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
def get_config(self, key, default=None, required=False):
|
||||
@@ -200,7 +240,7 @@ class PairedImageDataset(Dataset):
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
def get_all_prompts(self):
|
||||
@@ -368,6 +408,23 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
# repeat the list
|
||||
file_list = file_list * self.dataset_config.num_repeats
|
||||
|
||||
if self.dataset_config.standardize_images:
|
||||
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
|
||||
NormalizeMethod = NormalizeSDXLTransform
|
||||
else:
|
||||
NormalizeMethod = NormalizeSD15Transform
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
RescaleTransform(),
|
||||
NormalizeMethod(),
|
||||
])
|
||||
else:
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
RescaleTransform(),
|
||||
])
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
bad_count = 0
|
||||
@@ -375,7 +432,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
try:
|
||||
file_item = FileItemDTO(
|
||||
path=file,
|
||||
dataset_config=dataset_config
|
||||
dataset_config=dataset_config,
|
||||
dataloader_transforms=self.transform,
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
@@ -411,10 +469,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user