mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Upgraded to dev for t2i on diffusers. Minor migrations to make it work.
This commit is contained in:
@@ -77,7 +77,7 @@ class TrainConfig:
|
||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
|
||||
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
|
||||
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
|
||||
self.batch_size: int = kwargs.get('batch_size', 1)
|
||||
self.dtype: str = kwargs.get('dtype', 'fp32')
|
||||
self.xformers = kwargs.get('xformers', False)
|
||||
|
||||
@@ -12,6 +12,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
||||
from tqdm import tqdm
|
||||
import albumentations as A
|
||||
|
||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
@@ -268,6 +269,37 @@ class PairedImageDataset(Dataset):
|
||||
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path = img_path_or_tuple[1]
|
||||
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
|
||||
# always use image #2 (pos) to choose the bucket resolution
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
width=img2.width,
|
||||
height=img2.height,
|
||||
resolution=self.size
|
||||
)
|
||||
|
||||
# images will have the same base dimensions, but may be trimmed. We need to shrink and then center crop
|
||||
if bucket_resolution['width'] > bucket_resolution['height']:
|
||||
img1_scale_to_height = bucket_resolution["height"]
|
||||
img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
|
||||
img2_scale_to_height = bucket_resolution["height"]
|
||||
img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
|
||||
else:
|
||||
img1_scale_to_width = bucket_resolution["width"]
|
||||
img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
|
||||
img2_scale_to_width = bucket_resolution["width"]
|
||||
img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
|
||||
|
||||
img1_crop_height = bucket_resolution["height"]
|
||||
img1_crop_width = bucket_resolution["width"]
|
||||
img2_crop_height = bucket_resolution["height"]
|
||||
img2_crop_width = bucket_resolution["width"]
|
||||
|
||||
# scale then center crop images
|
||||
img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
|
||||
img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
|
||||
img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
|
||||
img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
|
||||
|
||||
# combine them side by side
|
||||
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
|
||||
img.paste(img1, (0, 0))
|
||||
@@ -275,15 +307,14 @@ class PairedImageDataset(Dataset):
|
||||
else:
|
||||
img_path = img_path_or_tuple
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
height = self.size
|
||||
# determine width to keep aspect ratio
|
||||
width = int(img.size[0] * height / img.size[1])
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((width, height), Image.BICUBIC)
|
||||
|
||||
prompt = self.get_prompt_item(index)
|
||||
|
||||
height = self.size
|
||||
# determine width to keep aspect ratio
|
||||
width = int(img.size[0] * height / img.size[1])
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((width, height), Image.BICUBIC)
|
||||
img = self.transform(img)
|
||||
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
@@ -122,11 +122,14 @@ class ToolkitModuleMixin:
|
||||
|
||||
return lx * scale
|
||||
|
||||
def forward(self: Module, x):
|
||||
# this may get an additional positional arg or not
|
||||
|
||||
def forward(self: Module, x, *args, **kwargs):
|
||||
# diffusers added a scale argument to resnet; not sure what it does, so any extra args are passed through to the original forward
|
||||
if self._multiplier is None:
|
||||
self.set_multiplier(0.0)
|
||||
|
||||
org_forwarded = self.org_forward(x)
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
multiplier = self._multiplier.clone().detach()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user