Upgraded to dev for t2i on diffusers. Minor migrations to make it work.

This commit is contained in:
Jaret Burkett
2023-09-11 14:46:06 -06:00
parent 083cefa78c
commit e8583860ad
7 changed files with 356 additions and 12 deletions

View File

@@ -77,7 +77,7 @@ class TrainConfig:
self.optimizer_params = kwargs.get('optimizer_params', {})
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 50)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)

View File

@@ -12,6 +12,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
@@ -268,6 +269,37 @@ class PairedImageDataset(Dataset):
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
img_path = img_path_or_tuple[1]
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
# always use # 2 (pos)
bucket_resolution = get_bucket_for_image_size(
width=img2.width,
height=img2.height,
resolution=self.size
)
# images will be same base dimension, but may be trimmed. We need to shrink and then central crop
if bucket_resolution['width'] > bucket_resolution['height']:
img1_scale_to_height = bucket_resolution["height"]
img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
img2_scale_to_height = bucket_resolution["height"]
img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
else:
img1_scale_to_width = bucket_resolution["width"]
img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
img2_scale_to_width = bucket_resolution["width"]
img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
img1_crop_height = bucket_resolution["height"]
img1_crop_width = bucket_resolution["width"]
img2_crop_height = bucket_resolution["height"]
img2_crop_width = bucket_resolution["width"]
# scale then center crop images
img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
# combine them side by side
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
img.paste(img1, (0, 0))
@@ -275,15 +307,14 @@ class PairedImageDataset(Dataset):
else:
img_path = img_path_or_tuple
img = exif_transpose(Image.open(img_path)).convert('RGB')
height = self.size
# determine width to keep aspect ratio
width = int(img.size[0] * height / img.size[1])
# Downscale the source image first
img = img.resize((width, height), Image.BICUBIC)
prompt = self.get_prompt_item(index)
height = self.size
# determine width to keep aspect ratio
width = int(img.size[0] * height / img.size[1])
# Downscale the source image first
img = img.resize((width, height), Image.BICUBIC)
img = self.transform(img)
return img, prompt, (self.neg_weight, self.pos_weight)

View File

@@ -122,11 +122,14 @@ class ToolkitModuleMixin:
return lx * scale
def forward(self: Module, x):
# this may get an additional positional arg or not
def forward(self: Module, x, *args, **kwargs):
# diffusers added scale to resnet.. not sure what it does
if self._multiplier is None:
self.set_multiplier(0.0)
org_forwarded = self.org_forward(x)
org_forwarded = self.org_forward(x, *args, **kwargs)
lora_output = self._call_forward(x)
multiplier = self._multiplier.clone().detach()