Built base interfaces for a DTO to handle batch information transports for the dataloader

This commit is contained in:
Jaret Burkett
2023-08-28 12:43:31 -06:00
parent 71da78c8af
commit e866c75638
7 changed files with 186 additions and 76 deletions

View File

@@ -86,6 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
if embedding_raw is not None:
self.embed_config = EmbeddingConfig(**embedding_raw)
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.model_config.name_or_path = latest_save_path
self.sd = StableDiffusion(
device=self.device,
model_config=self.model_config,
@@ -113,7 +121,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero-pad 9 digits
step_num = f"_{str(step).zfill(9)}"
filename = f"[time]_{step_num}_[count].png"
filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
output_path = os.path.join(sample_folder, filename)
@@ -142,6 +150,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
num_inference_steps=sample_config.sample_steps,
network_multiplier=sample_config.network_multiplier,
output_path=output_path,
output_ext=sample_config.ext,
))
# send to be generated

View File

@@ -3,6 +3,7 @@ import time
from typing import List, Optional, Literal
import random
ImgExt = Literal['jpg', 'png', 'webp']
class SaveConfig:
def __init__(self, **kwargs):
@@ -31,6 +32,7 @@ class SampleConfig:
self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
class NetworkConfig:
@@ -158,7 +160,10 @@ class SliderConfig:
class DatasetConfig:
caption_type: Literal["txt", "caption"] = 'txt'
"""
Dataset config for sd-datasets
"""
def __init__(self, **kwargs):
self.type = kwargs.get('type', 'image') # sd, slider, reference
@@ -172,6 +177,10 @@ class DatasetConfig:
self.buckets: bool = kwargs.get('buckets', False)
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
self.is_reg: bool = kwargs.get('is_reg', False)
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
class GenerateImageConfig:
@@ -191,7 +200,7 @@ class GenerateImageConfig:
# the tag [time] will be replaced with milliseconds since epoch
output_path: str = None, # full image path
output_folder: str = None, # folder to save image in if output_path is not specified
output_ext: str = 'png', # extension to save image as if output_path is not specified
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
output_tail: str = '', # tail to add to output filename
add_prompt_file: bool = False, # add a prompt file with generated image
):

View File

@@ -15,6 +15,8 @@ import albumentations as A
from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO
class ImageDataset(Dataset, CaptionMixin):
@@ -296,20 +298,6 @@ def print_once(msg):
printed_messages.append(msg)
class FileItem:
    """Lightweight record describing one image file and its scale/crop geometry."""

    def __init__(self, **kwargs):
        get = kwargs.get
        # raw file location and native dimensions
        self.path = get('path', None)
        self.width = get('width', None)
        self.height = get('height', None)
        # resize target; defaults to the native size (scaling is applied before cropping)
        self.scale_to_width = get('scale_to_width', self.width)
        self.scale_to_height = get('scale_to_height', self.height)
        # crop window, expressed in scaled-image coordinates
        self.crop_x = get('crop_x', 0)
        self.crop_y = get('crop_y', 0)
        self.crop_width = get('crop_width', self.scale_to_width)
        self.crop_height = get('crop_height', self.scale_to_height)
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
@@ -325,7 +313,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
self.file_list: List['FileItem'] = []
self.file_list: List['FileItemDTO'] = []
# get the file list
file_list = [
@@ -344,14 +332,16 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
f'This process is faster for png, jpeg')
img = Image.open(file)
h, w = img.size
# TODO allow smaller images
if int(min(h, w) * self.scale) >= self.resolution:
self.file_list.append(
FileItem(
FileItemDTO(
path=file,
width=w,
height=h,
scale_to_width=int(w * self.scale),
scale_to_height=int(h * self.scale),
dataset_config=dataset_config
)
)
else:

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING
import torch
import random
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
class FileItemDTO(CaptionProcessingDTOMixin):
    """Data-transfer object describing one dataset file: path, caption info,
    scale/crop geometry, and the DatasetConfig it was loaded under."""

    def __init__(self, **kwargs):
        self.path = kwargs.get('path', None)
        # caption sidecar path and its raw text, if already loaded
        self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get('raw_caption', None)
        # native image dimensions
        self.width: int = kwargs.get('width', None)
        self.height: int = kwargs.get('height', None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
        self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
        # crop values are from scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
        # process config
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # Fall back to a neutral weight of 1.0 when no dataset_config is supplied;
        # previously this unconditionally dereferenced dataset_config and raised
        # AttributeError on the documented default of None.
        self.network_weight: float = (
            self.dataset_config.network_weight if self.dataset_config is not None else 1.0
        )
        # NOTE(review): original attribute name looks like a typo of
        # `network_weight`; kept as an alias for backward compatibility —
        # confirm no callers rely on it before removing.
        self.network_network_weight: float = self.network_weight
class DataLoaderBatchDTO:
    """Transport object carrying one batch entry through the dataloader."""

    def __init__(self, **kwargs):
        # file item for this batch and the dataset config it originated from;
        # both default to None when not supplied
        self.file_item: 'FileItemDTO' = kwargs.get('file_item')
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config')

View File

@@ -1,9 +1,19 @@
import os
import random
from typing import TYPE_CHECKING, List, Dict
from toolkit.prompt_utils import inject_trigger_into_prompt
if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
# def get_associated_caption_from_img_path(img_path):
class CaptionMixin:
def get_caption_item(self, index):
def get_caption_item(self: 'AiToolkitDataset', index):
if not hasattr(self, 'caption_type'):
raise Exception('caption_type not found on class instance')
if not hasattr(self, 'file_list'):
@@ -48,7 +58,7 @@ class CaptionMixin:
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import FileItem
from toolkit.data_transfer_object.data_loader import FileItemDTO
class Bucket:
@@ -63,14 +73,14 @@ class BucketsMixin:
self.buckets: Dict[str, Bucket] = {}
self.batch_indices: List[List[int]] = []
def build_batch_indices(self):
def build_batch_indices(self: 'AiToolkitDataset'):
for key, bucket in self.buckets.items():
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch)
def setup_buckets(self):
def setup_buckets(self: 'AiToolkitDataset'):
if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
if not hasattr(self, 'dataset_config'):
@@ -79,7 +89,7 @@ class BucketsMixin:
config: 'DatasetConfig' = self.dataset_config
resolution = config.resolution
bucket_tolerance = config.bucket_tolerance
file_list: List['FileItem'] = self.file_list
file_list: List['FileItemDTO'] = self.file_list
# make sure out resolution is divisible by bucket_tolerance
if resolution % bucket_tolerance != 0:
@@ -146,3 +156,48 @@ class BucketsMixin:
print(f'{len(self.buckets)} buckets made')
# file buckets made
class CaptionProcessingDTOMixin:
    """Mixin giving a FileItemDTO caption post-processing: caption/token
    dropout, token shuffling, and trigger-word injection."""

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        """Return the processed caption string for this file item."""
        cfg = self.dataset_config
        caption_text = self.raw_caption if self.raw_caption is not None else ''

        # whole-caption dropout: occasionally train on an empty caption
        if cfg.caption_dropout_rate > 0:
            if random.random() < cfg.caption_dropout_rate:
                return ''

        # split on commas, strip whitespace, and discard empty tokens
        tokens = [part.strip() for part in caption_text.split(',')]
        tokens = [tok for tok in tokens if tok]

        if cfg.shuffle_tokens:
            random.shuffle(tokens)

        # per-token dropout: each token survives independently
        if cfg.token_dropout_rate > 0:
            tokens = [tok for tok in tokens if random.random() > cfg.token_dropout_rate]

        caption = ', '.join(tokens)
        return inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

View File

@@ -1,12 +1,11 @@
import os
from typing import Optional, TYPE_CHECKING, List
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import random
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import itertools
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
ENHANCE_NEGATIVE = 1
class PromptEmbeds:
    """Bundle of text-encoder outputs for a prompt.

    SDXL-style pipelines produce both per-token embeddings and a pooled
    embedding, so callers may pass a (text_embeds, pooled_embeds) pair;
    SD 1.x/2.x pipelines produce a single tensor and pooled_embeds is None.
    """
    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    # annotation fixed: Tuple[torch.Tensor] meant a 1-tuple; a variadic
    # tuple of tensors is what is actually accepted
    def __init__(self, args: Union[Tuple[torch.Tensor, ...], List[torch.Tensor], torch.Tensor]) -> None:
        if isinstance(args, (list, tuple)):
            # xl: (text_embeds, pooled_embeds)
            self.text_embeds = args[0]
            # tolerate a 1-element sequence by treating pooled as absent
            # (previously this raised IndexError)
            self.pooled_embeds = args[1] if len(args) > 1 else None
        else:
            # sdv1.x, sdv2.x: a single tensor, no pooled embedding
            self.text_embeds = args
            self.pooled_embeds = None

    def to(self, *args, **kwargs):
        """Move/cast both tensors (device/dtype); returns self for chaining."""
        self.text_embeds = self.text_embeds.to(*args, **kwargs)
        if self.pooled_embeds is not None:
            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
        return self
class EncodedPromptPair:
def __init__(
self,
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
latent_list.append(neg_latent)
return torch.cat(latent_list, dim=0)
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace placeholder tokens in *prompt* with *trigger*.

    Every occurrence of "[name]" and "[trigger]" (plus any caller-supplied
    placeholders in *to_replace_list*) is replaced by *trigger*. If the
    trigger then appears nowhere and *add_if_not_present* is True, it is
    prepended to the prompt.

    Args:
        prompt: prompt text to process.
        trigger: replacement word; when None the prompt is returned unchanged.
        to_replace_list: extra placeholder strings to replace, or None.
        add_if_not_present: prepend the trigger when it is absent.

    Returns:
        The processed prompt string.
    """
    if trigger is None:
        return prompt
    output_prompt = prompt
    default_replacements = ["[name]", "[trigger]"]
    replace_with = trigger
    if to_replace_list is None:
        to_replace_list = default_replacements
    else:
        # build a new list instead of `+=` so the caller's list is not
        # mutated in place as a side effect
        to_replace_list = to_replace_list + default_replacements
    # remove duplicates
    to_replace_list = list(set(to_replace_list))
    # replace them all
    for to_replace in to_replace_list:
        output_prompt = output_prompt.replace(to_replace, replace_with)
    # see how many times replace_with is in the prompt
    num_instances = output_prompt.count(replace_with)
    if num_instances == 0 and add_if_not_present:
        # add it to the beginning of the prompt
        output_prompt = replace_with + " " + output_prompt
    if num_instances > 1:
        print(
            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
    return output_prompt

View File

@@ -16,6 +16,7 @@ from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
@@ -63,25 +64,7 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class PromptEmbeds:
    # NOTE(review): this appears to duplicate toolkit.prompt_utils.PromptEmbeds —
    # this module now imports that class alongside inject_trigger_into_prompt
    # (see the prompt_utils import at the top of the file); confirm this local
    # copy can be deleted.
    """Bundle of text-encoder outputs: per-token embeddings plus an optional
    pooled embedding (SDXL only)."""
    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
        # a list/tuple is SDXL-style output: (text_embeds, pooled_embeds)
        if isinstance(args, list) or isinstance(args, tuple):
            # xl
            self.text_embeds = args[0]
            self.pooled_embeds = args[1]
        else:
            # sdv1.x, sdv2.x: single tensor, no pooled embedding
            self.text_embeds = args
            self.pooled_embeds = None

    def to(self, *args, **kwargs):
        # forward device/dtype moves to both tensors; returns self for chaining
        self.text_embeds = self.text_embeds.to(*args, **kwargs)
        if self.pooled_embeds is not None:
            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
        return self
# if is type checking
@@ -708,38 +691,12 @@ class StableDiffusion:
raise ValueError(f"Unknown weight name: {name}")
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace trigger placeholders in *prompt* with *trigger*.

    Thin wrapper that delegates to the shared module-level
    inject_trigger_into_prompt imported from toolkit.prompt_utils at the
    top of this file. The previous inline copy of that logic (left above
    an unreachable delegating return by the refactor) has been removed so
    the behavior lives in exactly one place.
    """
    # Inside the method body this name resolves to the module-level import,
    # not to this method (class scope is not captured by enclosed code).
    return inject_trigger_into_prompt(
        prompt,
        trigger=trigger,
        to_replace_list=to_replace_list,
        add_if_not_present=add_if_not_present,
    )
def state_dict(self, vae=True, text_encoder=True, unet=True):
state_dict = OrderedDict()