mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-04 12:39:58 +00:00
Built base interfaces for a DTO to handle batch information transport for the dataloader
This commit is contained in:
@@ -86,6 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.model_config.name_or_path = latest_save_path
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
@@ -113,7 +121,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# zero-pad 9 digits
|
||||
step_num = f"_{str(step).zfill(9)}"
|
||||
|
||||
filename = f"[time]_{step_num}_[count].png"
|
||||
filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
|
||||
|
||||
output_path = os.path.join(sample_folder, filename)
|
||||
|
||||
@@ -142,6 +150,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
network_multiplier=sample_config.network_multiplier,
|
||||
output_path=output_path,
|
||||
output_ext=sample_config.ext,
|
||||
))
|
||||
|
||||
# send to be generated
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from typing import List, Optional, Literal
|
||||
import random
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -31,6 +32,7 @@ class SampleConfig:
|
||||
self.sample_steps = kwargs.get('sample_steps', 20)
|
||||
self.network_multiplier = kwargs.get('network_multiplier', 1)
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||
|
||||
|
||||
class NetworkConfig:
|
||||
@@ -158,7 +160,10 @@ class SliderConfig:
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
caption_type: Literal["txt", "caption"] = 'txt'
|
||||
"""
|
||||
Dataset config for sd-datasets
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.type = kwargs.get('type', 'image') # sd, slider, reference
|
||||
@@ -172,6 +177,10 @@ class DatasetConfig:
|
||||
self.buckets: bool = kwargs.get('buckets', False)
|
||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
@@ -191,7 +200,7 @@ class GenerateImageConfig:
|
||||
# the tag [time] will be replaced with milliseconds since epoch
|
||||
output_path: str = None, # full image path
|
||||
output_folder: str = None, # folder to save image in if output_path is not specified
|
||||
output_ext: str = 'png', # extension to save image as if output_path is not specified
|
||||
output_ext: str = ImgExt, # extension to save image as if output_path is not specified
|
||||
output_tail: str = '', # tail to add to output filename
|
||||
add_prompt_file: bool = False, # add a prompt file with generated image
|
||||
):
|
||||
|
||||
@@ -15,6 +15,8 @@ import albumentations as A
|
||||
from toolkit import image_utils
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
@@ -296,20 +298,6 @@ def print_once(msg):
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItem:
    """Record describing one image file plus its scale/crop geometry.

    All values come in through keyword arguments; scale defaults mirror the
    raw dimensions and crop defaults cover the full scaled image.
    """

    def __init__(self, **kwargs):
        get = kwargs.get
        self.path = get('path', None)
        self.width = get('width', None)
        self.height = get('height', None)
        # scaling is applied before cropping
        self.scale_to_width = get('scale_to_width', self.width)
        self.scale_to_height = get('scale_to_height', self.height)
        # crop coordinates are expressed in the scaled image's pixel space
        self.crop_x = get('crop_x', 0)
        self.crop_y = get('crop_y', 0)
        self.crop_width = get('crop_width', self.scale_to_width)
        self.crop_height = get('crop_height', self.scale_to_height)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
|
||||
@@ -325,7 +313,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
# we always random crop if random scale is enabled
|
||||
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
||||
self.resolution = dataset_config.resolution
|
||||
self.file_list: List['FileItem'] = []
|
||||
self.file_list: List['FileItemDTO'] = []
|
||||
|
||||
# get the file list
|
||||
file_list = [
|
||||
@@ -344,14 +332,16 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
f'This process is faster for png, jpeg')
|
||||
img = Image.open(file)
|
||||
h, w = img.size
|
||||
# TODO allow smaller images
|
||||
if int(min(h, w) * self.scale) >= self.resolution:
|
||||
self.file_list.append(
|
||||
FileItem(
|
||||
FileItemDTO(
|
||||
path=file,
|
||||
width=w,
|
||||
height=h,
|
||||
scale_to_width=int(w * self.scale),
|
||||
scale_to_height=int(h * self.scale),
|
||||
dataset_config=dataset_config
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
36
toolkit/data_transfer_object/data_loader.py
Normal file
36
toolkit/data_transfer_object/data_loader.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import torch
|
||||
import random
|
||||
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
|
||||
|
||||
class FileItemDTO(CaptionProcessingDTOMixin):
    """Data-transfer object describing one dataset file for the dataloader.

    Carries the file path, caption info, raw/scaled dimensions, crop window
    (in scaled-pixel space), and the originating dataset config.
    """

    def __init__(self, **kwargs):
        self.path = kwargs.get('path', None)
        self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get('raw_caption', None)
        self.width: int = kwargs.get('width', None)
        self.height: int = kwargs.get('height', None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
        self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
        # crop values are from scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)

        # process config
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)

        # BUG FIX: the original read dataset_config.network_weight
        # unconditionally, raising AttributeError whenever the
        # dataset_config kwarg was omitted (its documented default is None).
        # Fall back to 1.0, matching DatasetConfig's own default.
        # NOTE(review): the doubled "network_network_weight" name looks like
        # a typo, but it is kept as-is for backward compatibility.
        self.network_network_weight: float = (
            self.dataset_config.network_weight
            if self.dataset_config is not None
            else 1.0
        )
|
||||
|
||||
class DataLoaderBatchDTO:
    """Container for one batch entry produced by the dataloader."""

    def __init__(self, **kwargs):
        # file item this batch entry was built from (None when not supplied)
        self.file_item: 'FileItemDTO' = kwargs.get('file_item')
        # config of the dataset the item came from (None when not supplied)
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config')
||||
@@ -1,9 +1,19 @@
|
||||
import os
|
||||
import random
|
||||
from typing import TYPE_CHECKING, List, Dict
|
||||
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.data_loader import AiToolkitDataset
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
|
||||
|
||||
class CaptionMixin:
|
||||
def get_caption_item(self, index):
|
||||
def get_caption_item(self: 'AiToolkitDataset', index):
|
||||
if not hasattr(self, 'caption_type'):
|
||||
raise Exception('caption_type not found on class instance')
|
||||
if not hasattr(self, 'file_list'):
|
||||
@@ -48,7 +58,7 @@ class CaptionMixin:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
from toolkit.data_loader import FileItem
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
class Bucket:
|
||||
@@ -63,14 +73,14 @@ class BucketsMixin:
|
||||
self.buckets: Dict[str, Bucket] = {}
|
||||
self.batch_indices: List[List[int]] = []
|
||||
|
||||
    def build_batch_indices(self: 'AiToolkitDataset'):
        """Slice each bucket's file indices into batches of at most self.batch_size.

        Appends every batch (a list of indices into self.file_list) to
        self.batch_indices. A bucket's trailing batch may be shorter than
        batch_size; batches never span two buckets.
        """
        for key, bucket in self.buckets.items():
            # walk the bucket's index list in batch_size strides
            for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
                end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
                batch = bucket.file_list_idx[start_idx:end_idx]
                self.batch_indices.append(batch)
||||
|
||||
def setup_buckets(self):
|
||||
def setup_buckets(self: 'AiToolkitDataset'):
|
||||
if not hasattr(self, 'file_list'):
|
||||
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
|
||||
if not hasattr(self, 'dataset_config'):
|
||||
@@ -79,7 +89,7 @@ class BucketsMixin:
|
||||
config: 'DatasetConfig' = self.dataset_config
|
||||
resolution = config.resolution
|
||||
bucket_tolerance = config.bucket_tolerance
|
||||
file_list: List['FileItem'] = self.file_list
|
||||
file_list: List['FileItemDTO'] = self.file_list
|
||||
|
||||
# make sure out resolution is divisible by bucket_tolerance
|
||||
if resolution % bucket_tolerance != 0:
|
||||
@@ -146,3 +156,48 @@ class BucketsMixin:
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
|
||||
# file buckets made
|
||||
|
||||
|
||||
class CaptionProcessingDTOMixin:
    """Mixin giving a FileItemDTO caption post-processing.

    Applies (in order): whole-caption dropout, comma tokenization with
    whitespace trimming, optional token shuffling, per-token dropout, and
    trigger-word injection.
    """

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        """Return the processed caption string for this file item."""
        raw_caption = '' if self.raw_caption is None else self.raw_caption
        cfg = self.dataset_config

        # whole-caption dropout: with probability caption_dropout_rate,
        # the caption is dropped entirely
        if cfg.caption_dropout_rate > 0:
            if random.random() < cfg.caption_dropout_rate:
                return ''

        # split on commas, trim whitespace, discard empty tokens
        token_list = [tok for tok in (part.strip() for part in raw_caption.split(',')) if tok]

        if cfg.shuffle_tokens:
            random.shuffle(token_list)

        # per-token dropout: each token survives independently
        if cfg.token_dropout_rate > 0:
            token_list = [
                tok for tok in token_list
                if random.random() > cfg.token_dropout_rate
            ]

        # reassemble and inject the trigger word
        caption = ', '.join(token_list)
        return inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
from typing import Optional, TYPE_CHECKING, List
|
||||
from typing import Optional, TYPE_CHECKING, List, Union, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import itertools
|
||||
|
||||
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
|
||||
ENHANCE_NEGATIVE = 1
|
||||
|
||||
|
||||
class PromptEmbeds:
    """Bundle of text-encoder outputs for a prompt.

    SDXL passes a (text_embeds, pooled_embeds) pair; SD 1.x/2.x pass a
    single tensor, in which case pooled_embeds is None.
    """
    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
        if isinstance(args, (list, tuple)):
            # SDXL-style pair
            self.text_embeds, self.pooled_embeds = args[0], args[1]
        else:
            # single-tensor SD 1.x / 2.x form
            self.text_embeds = args
            self.pooled_embeds = None

    def to(self, *args, **kwargs):
        """Move/cast the contained tensors; returns self for chaining."""
        self.text_embeds = self.text_embeds.to(*args, **kwargs)
        pooled = self.pooled_embeds
        if pooled is not None:
            self.pooled_embeds = pooled.to(*args, **kwargs)
        return self
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
|
||||
latent_list.append(neg_latent)
|
||||
|
||||
return torch.cat(latent_list, dim=0)
|
||||
|
||||
|
||||
def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace trigger placeholders in *prompt* with *trigger*.

    Every tag in *to_replace_list* plus the defaults ``[name]`` and
    ``[trigger]`` is substituted with *trigger*. If the trigger then appears
    nowhere and *add_if_not_present* is True, it is prepended. A warning is
    printed when it appears more than once.

    Returns the (possibly modified) prompt; returns *prompt* unchanged when
    *trigger* is None. The caller's *to_replace_list* is never mutated.
    """
    if trigger is None:
        return prompt
    output_prompt = prompt
    default_replacements = ["[name]", "[trigger]"]

    replace_with = trigger
    if to_replace_list is None:
        to_replace_list = default_replacements
    else:
        # BUG FIX: the original used `to_replace_list += default_replacements`,
        # which extended the caller's list in place — repeated calls kept
        # appending the defaults to the caller's object. Build a new list.
        to_replace_list = list(to_replace_list) + default_replacements

    # remove duplicates (replacements are independent, so order is irrelevant)
    to_replace_list = list(set(to_replace_list))

    # replace them all
    for to_replace in to_replace_list:
        output_prompt = output_prompt.replace(to_replace, replace_with)

    # see how many times replace_with is in the prompt
    num_instances = output_prompt.count(replace_with)

    if num_instances == 0 and add_if_not_present:
        # add it to the beginning of the prompt
        output_prompt = replace_with + " " + output_prompt

    if num_instances > 1:
        print(
            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

    return output_prompt
|
||||
|
||||
@@ -16,6 +16,7 @@ from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
@@ -63,25 +64,7 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
||||
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
||||
|
||||
|
||||
class PromptEmbeds:
    # NOTE(review): duplicate of toolkit.prompt_utils.PromptEmbeds — this
    # diff removes this copy; imports should point at prompt_utils instead.
    text_embeds: torch.Tensor
    pooled_embeds: Union[torch.Tensor, None]

    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
        # A list/tuple argument is the SDXL pair (text_embeds, pooled_embeds);
        # a bare tensor is the SD 1.x/2.x single embedding.
        if isinstance(args, list) or isinstance(args, tuple):
            # xl
            self.text_embeds = args[0]
            self.pooled_embeds = args[1]
        else:
            # sdv1.x, sdv2.x
            self.text_embeds = args
            self.pooled_embeds = None

    def to(self, *args, **kwargs):
        # Moves/casts both tensors in place; returns self so calls can chain.
        self.text_embeds = self.text_embeds.to(*args, **kwargs)
        if self.pooled_embeds is not None:
            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
        return self
|
||||
|
||||
|
||||
# if is type checking
|
||||
@@ -708,38 +691,12 @@ class StableDiffusion:
|
||||
raise ValueError(f"Unknown weight name: {name}")
|
||||
|
||||
    def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
        """Backward-compatible wrapper kept on StableDiffusion.

        The implementation moved to toolkit.prompt_utils.inject_trigger_into_prompt
        (imported at module level); this method simply delegates so existing
        callers of self.inject_trigger_into_prompt keep working.
        """
        return inject_trigger_into_prompt(
            prompt,
            trigger=trigger,
            to_replace_list=to_replace_list,
            add_if_not_present=add_if_not_present,
        )
|
||||
|
||||
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
Reference in New Issue
Block a user