Huge rework to move the batch to a DTO, making it far more modular for the future UI

This commit is contained in:
Jaret Burkett
2023-08-29 10:22:19 -06:00
parent bd758ff203
commit 714854ee86
10 changed files with 286 additions and 232 deletions
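At a glance: the training loop no longer unpacks a raw (imgs, prompts, dataset_config) tuple; it receives a DataLoaderBatchDTO. A minimal sketch of the new interface as this diff defines it (the surrounding training setup is assumed):

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

def consume_batch(batch: DataLoaderBatchDTO):
    imgs = batch.tensor                                     # stacked image tensors
    prompts = batch.get_caption_list()                      # per-item captions
    is_reg_list = batch.get_is_reg_list()                   # per-item regularization flags
    network_weight_list = batch.get_network_weight_list()   # per-item network weights
    # ... run the training step ...
    batch.cleanup()                                         # drop tensor references when done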

View File

@@ -35,6 +35,7 @@ class SDTrainer(BaseSDTrainProcess):
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
self.optimizer.zero_grad()
flush()
@@ -53,6 +54,9 @@ class SDTrainer(BaseSDTrainProcess):
else:
network = BlankNetwork()
# set the weights
network.multiplier = network_weight_list
# activate network if it exists
with network:
with torch.set_grad_enabled(grad_on_text_encoder):
@@ -114,5 +118,7 @@ class SDTrainer(BaseSDTrainProcess):
loss_dict = OrderedDict(
{'loss': loss.item()}
)
# reset network multiplier
network.multiplier = 1.0
return loss_dict
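The multiplier handling above follows a set-then-reset pattern: the per-item weights from the batch apply only for the duration of the step. A minimal sketch of the pattern (do_train_step is a hypothetical stand-in for the loss computation; this also assumes the network accepts a list of multipliers, as the assignment above implies):

network_weight_list = batch.get_network_weight_list()
network.multiplier = network_weight_list  # one weight per batch item
with network:
    loss = do_train_step(noisy_latents, noise, timesteps)  # hypothetical step function
network.multiplier = 1.0  # reset so later sampling/saving uses the default weight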

View File

@@ -19,7 +19,7 @@ config:
max_step_saves_to_keep: 5 # only affects step counts
datasets:
- folder_path: "/path/to/dataset"
caption_type: "txt"
caption_ext: "txt"
default_caption: "[trigger]"
buckets: true
resolution: 512

View File

@@ -6,6 +6,7 @@ from typing import Union
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -23,7 +24,7 @@ import torch
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config
def flush():
@@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.trigger_word = self.get_conf('trigger_word', None)
raw_datasets = self.get_conf('datasets', None)
if raw_datasets is not None and len(raw_datasets) > 0:
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
self.datasets = None
self.datasets_reg = None
if raw_datasets is not None and len(raw_datasets) > 0:
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' is in the OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.sd = StableDiffusion(
device=self.device,
@@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
def process_general_training_batch(self, batch):
with torch.no_grad():
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
if isinstance(dataset_config, list):
is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
else:
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]
imgs = batch.tensor
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
conditioned_prompts = []
@@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# resume state from embedding
self.step_num = self.embedding.step
self.start_step = self.step_num
# set trainable params
params = self.embedding.get_trainable_params()
@@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
with torch.no_grad():
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can; buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
is_reg_step = False
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
# don't do a reg step on sample or save steps as we don't want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
is_reg_step = True
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
@@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause()  # so it doesn't track time while paused
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
if is_sample_step:
# print above the progress bar
self.sample(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
if is_save_step:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
@@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
# end of step
self.step_num = step
# apply network normalizer if we are using it
if self.network is not None and self.network.is_normalizing:
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
batch.cleanup()
self.sample(self.step_num + 1)
print("")
self.save()

View File

@@ -1,76 +0,0 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os
from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
def flush():
torch.cuda.empty_cache()
gc.collect()
class LoRAHack:
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'suppression')
class TrainLoRAHack(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.hack_config = LoRAHack(**self.get_conf('hack', {}))
def hook_before_train_loop(self):
# we don't need text encoder so move it to cpu
self.sd.text_encoder.to("cpu")
flush()
# end hook_before_train_loop
if self.hack_config.type == 'suppression':
# set all params to self.current_suppression
params = self.network.parameters()
for param in params:
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = noise * 0.001
def supress_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
loss_dict = OrderedDict(
{'sup': 0.0}
)
# increase noise
for param in self.network.parameters():
# get random noise for each param
noise = torch.randn_like(param) - 0.5
# apply noise to param
param.data = param.data + noise * 0.001
return loss_dict
def hook_train_loop(self, batch):
if self.hack_config.type == 'suppression':
return self.supress_loop()
else:
raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
# end hook_train_loop

View File

@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess

View File

@@ -22,7 +22,7 @@ batch_size = 4
dataset_config = DatasetConfig(
folder_path=dataset_folder,
resolution=resolution,
caption_type='txt',
caption_ext='txt',
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,

View File

@@ -5,6 +5,7 @@ import random
ImgExt = Literal['jpg', 'png', 'webp']
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
@@ -167,9 +168,13 @@ class DatasetConfig:
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image') # sd, slider, reference
# will be legacy
self.folder_path: str = kwargs.get('folder_path', None)
# can be json or folder path
self.dataset_path: str = kwargs.get('dataset_path', None)
self.default_caption: str = kwargs.get('default_caption', None)
self.caption_type: str = kwargs.get('caption_type', None)
self.caption_ext: str = kwargs.get('caption_ext', None)
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)
@@ -182,6 +187,33 @@ class DatasetConfig:
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
# legacy compatibility
legacy_caption_type = kwargs.get('caption_type', None)
if legacy_caption_type:
self.caption_ext = legacy_caption_type
self.caption_type = self.caption_ext
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
"""
This just splits up the datasets by resolution so you don't have to do it manually
:param raw_config:
:return:
"""
# split up datasets by resolutions
new_config = []
for dataset in raw_config:
resolution = dataset.get('resolution', 512)
if isinstance(resolution, list):
resolution_list = resolution
else:
resolution_list = [resolution]
for res in resolution_list:
dataset_copy = dataset.copy()
dataset_copy['resolution'] = res
new_config.append(dataset_copy)
return new_config
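For example, one dataset entry carrying a list of resolutions expands into one shallow copy per resolution:

raw = [{'folder_path': '/path/to/dataset', 'resolution': [512, 768]}]
expanded = preprocess_dataset_raw_config(raw)
# expanded == [
#     {'folder_path': '/path/to/dataset', 'resolution': 512},
#     {'folder_path': '/path/to/dataset', 'resolution': 768},
# ]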
class GenerateImageConfig:
def __init__(

View File

@@ -1,3 +1,4 @@
import json
import os
import random
from typing import List
@@ -13,10 +14,9 @@ from tqdm import tqdm
import albumentations as A
from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
class ImageDataset(Dataset, CaptionMixin):
@@ -29,7 +29,7 @@ class ImageDataset(Dataset, CaptionMixin):
self.include_prompt = self.get_config('include_prompt', False)
self.default_prompt = self.get_config('default_prompt', '')
if self.include_prompt:
self.caption_type = self.get_config('caption_type', 'txt')
self.caption_type = self.get_config('caption_ext', 'txt')
else:
self.caption_type = None
# we always random crop if random scale is enabled
@@ -288,24 +288,17 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
printed_messages = []
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
super().__init__()
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
folder_path = dataset_config.folder_path
self.dataset_path = dataset_config.dataset_path
if self.dataset_path is None:
self.dataset_path = folder_path
self.caption_type = dataset_config.caption_ext
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
@@ -313,147 +306,96 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
self.caption_dict = None
self.file_list: List['FileItemDTO'] = []
# get the file list
file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# check if dataset_path is a folder or json
if os.path.isdir(self.dataset_path):
file_list = [
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
else:
# assume json
with open(self.dataset_path, 'r') as f:
self.caption_dict = json.load(f)
# keys are file paths
file_list = list(self.caption_dict.keys())
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0
for file in tqdm(file_list):
try:
w, h = image_utils.get_image_size(file)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
img = Image.open(file)
h, w = img.size
# TODO allow smaller images
if int(min(h, w) * self.scale) >= self.resolution:
self.file_list.append(
FileItemDTO(
path=file,
width=w,
height=h,
scale_to_width=int(w * self.scale),
scale_to_height=int(h * self.scale),
dataset_config=dataset_config
)
)
else:
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
)
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
bad_count += 1
else:
self.file_list.append(file_item)
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
self.setup_epoch()
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def setup_epoch(self):
# TODO: set this up to redo cropping and everything else
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
def __len__(self):
if self.dataset_config.buckets:
return len(self.batch_indices)
return len(self.file_list)
def _get_single_item(self, index):
def _get_single_item(self, index) -> 'FileItemDTO':
file_item = self.file_list[index]
# todo make sure this matches
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
w, h = img.size
if w > h and file_item.scale_to_width < file_item.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
else:
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
# todo convert it all
dataset_config_dict = {
"is_reg": 1 if self.dataset_config.is_reg else 0,
}
if self.caption_type is not None:
prompt = self.get_caption_item(index)
return img, prompt, dataset_config_dict
else:
return img, dataset_config_dict
file_item.load_and_process_image(self.transform)
file_item.load_caption(self.caption_dict)
return file_item
def __getitem__(self, item):
if self.dataset_config.buckets:
# for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets
# we collate ourselves
idx_list = self.batch_indices[item]
tensor_list = []
prompt_list = []
dataset_config_dict_list = []
for idx in idx_list:
if self.caption_type is not None:
img, prompt, dataset_config_dict = self._get_single_item(idx)
prompt_list.append(prompt)
dataset_config_dict_list.append(dataset_config_dict)
else:
img, dataset_config_dict = self._get_single_item(idx)
dataset_config_dict_list.append(dataset_config_dict)
tensor_list.append(img.unsqueeze(0))
if self.caption_type is not None:
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
else:
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
return [self._get_single_item(idx) for idx in idx_list]
else:
# Dataloader is batching
return self._get_single_item(item)
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
if dataset_options is None or len(dataset_options) == 0:
return None
datasets = []
has_buckets = False
dataset_config_list = []
# preprocess them all
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
dataset_config_list.append(dataset_option)
else:
config = DatasetConfig(**dataset_option)
# preprocess raw data
split_configs = preprocess_dataset_raw_config([dataset_option])
for x in split_configs:
dataset_config_list.append(DatasetConfig(**x))
for config in dataset_config_list:
if config.type == 'image':
dataset = AiToolkitDataset(config, batch_size=batch_size)
datasets.append(dataset)
@@ -463,21 +405,28 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
# todo build scheduler that can get buckets from all datasets that match
# todo and evenly distribute reg images
def dto_collation(batch: List['FileItemDTO']):
# create DTO batch
batch = DataLoaderBatchDTO(
file_items=batch
)
return batch
if has_buckets:
# make sure they all have buckets
for dataset in datasets:
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
def custom_collate_fn(batch):
# just return as is
return batch
data_loader = DataLoader(
concatenated_dataset,
batch_size=None, # we batch in the dataloader
batch_size=None, # we batch in the datasets for now
drop_last=False,
shuffle=True,
collate_fn=custom_collate_fn, # Use the custom collate function
collate_fn=dto_collation, # Use the custom collate function
num_workers=2
)
else:
@@ -485,6 +434,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
num_workers=2,
collate_fn=dto_collation
)
return data_loader
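With dto_collation wired in on both paths, iterating the loader always yields DataLoaderBatchDTO objects, whether the dataset pre-batched items itself (buckets, batch_size=None) or the DataLoader batched individual FileItemDTOs. A minimal usage sketch (the dataset path is a placeholder):

dataloader = get_dataloader_from_datasets(
    [{'folder_path': '/path/to/dataset', 'caption_ext': 'txt', 'buckets': True, 'resolution': 512}],
    batch_size=4,
)
for batch in dataloader:  # each batch is a DataLoaderBatchDTO
    imgs = batch.tensor
    captions = batch.get_caption_list()
    # ... training step ...
    batch.cleanup()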

View File

@@ -1,36 +1,84 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Union
import torch
import random
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
printed_messages = []
class FileItemDTO(CaptionProcessingDTOMixin):
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
def __init__(self, **kwargs):
self.path = kwargs.get('path', None)
self.caption_path: str = kwargs.get('caption_path', None)
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
# process width and height
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once('Warning: Some images in the dataset cannot be read quickly. '
           'This process is faster for png and jpeg')
img = exif_transpose(Image.open(self.path))
w, h = img.size  # PIL gives (width, height)
self.width: int = w
self.height: int = h
# self.caption_path: str = kwargs.get('caption_path', None)
self.raw_caption: str = kwargs.get('raw_caption', None)
self.width: int = kwargs.get('width', None)
self.height: int = kwargs.get('height', None)
# we scale first, then crop
self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
# crop values are from scaled size
self.crop_x: int = kwargs.get('crop_x', 0)
self.crop_y: int = kwargs.get('crop_y', 0)
self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
# process config
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg
self.tensor: Union[torch.Tensor, None] = None
def cleanup(self):
self.tensor = None
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
self.file_item: 'FileItemDTO' = kwargs.get('file_item', None)
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]
def get_network_weight_list(self):
return [x.network_weight for x in self.file_items]
def get_caption_list(
self,
trigger=None,
to_replace_list=None,
add_if_not_present=True
):
return [x.get_caption(
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present
) for x in self.file_items]
def cleanup(self):
self.tensor = None
for file_item in self.file_items:
file_item.cleanup()

View File

@@ -1,8 +1,11 @@
import os
import random
from typing import TYPE_CHECKING, List, Dict
from typing import TYPE_CHECKING, List, Dict, Union
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose
if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
@@ -159,6 +162,38 @@ class BucketsMixin:
class CaptionProcessingDTOMixin:
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
if self.raw_caption is not None:
# we already loaded it
pass
elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
self.raw_caption = caption_dict[self.path]["caption"]
else:
# see if prompt file exists
path_no_ext = os.path.splitext(self.path)[0]
prompt_ext = self.dataset_config.caption_ext
prompt_path = f"{path_no_ext}.{prompt_ext}"
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
# replace newlines with commas (\n and \r, for all operating systems)
prompt = prompt.replace('\n', ', ')
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
else:
prompt = ''
if self.dataset_config.default_caption is not None:
prompt = self.dataset_config.default_caption
self.raw_caption = prompt
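When dataset_path points at a JSON file rather than a folder, keys are image paths and each value carries a 'caption' field; a minimal example of the shape load_caption expects (paths are placeholders):

caption_dict = {
    '/path/to/img_0001.jpg': {'caption': 'a photo of [trigger] on a beach'},
    '/path/to/img_0002.jpg': {'caption': 'a close-up portrait'},
}
file_item.load_caption(caption_dict)  # sets raw_caption from the matching entry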
def get_caption(
self: 'FileItemDTO',
trigger=None,
@@ -201,3 +236,51 @@ class CaptionProcessingDTOMixin:
caption = ', '.join(token_list)
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
return caption
class ImageProcessingDTOMixin:
def load_and_process_image(
self: 'FileItemDTO',
transform: Union[None, transforms.Compose]
):
# todo make sure this matches
img = exif_transpose(Image.open(self.path)).convert('RGB')
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
else:
# Downscale the source image first
img = img.resize(
(int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
Image.BICUBIC)
min_img_size = min(img.size)
if self.dataset_config.random_crop:
if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
if min_img_size < self.dataset_config.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}")
scale_size = self.dataset_config.resolution
else:
scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.dataset_config.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
if transform:
img = transform(img)
self.tensor = img