diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index bd8790d9..f56bd15a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -86,6 +86,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if embedding_raw is not None:
             self.embed_config = EmbeddingConfig(**embedding_raw)
+
+        # check to see if we have a latest save
+        latest_save_path = self.get_latest_save_path()
+
+        if latest_save_path is not None:
+            print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+            self.model_config.name_or_path = latest_save_path
+
         self.sd = StableDiffusion(
             device=self.device,
             model_config=self.model_config,
@@ -113,7 +121,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # zero-pad 9 digits
         step_num = f"_{str(step).zfill(9)}"
-        filename = f"[time]_{step_num}_[count].png"
+        filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
 
         output_path = os.path.join(sample_folder, filename)
 
@@ -142,6 +150,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 num_inference_steps=sample_config.sample_steps,
                 network_multiplier=sample_config.network_multiplier,
                 output_path=output_path,
+                output_ext=sample_config.ext,
             ))
 
         # send to be generated
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 5005bee4..7f7744d8 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -3,6 +3,7 @@ import time
 from typing import List, Optional, Literal
 import random
+ImgExt = Literal['jpg', 'png', 'webp']
 
 
 class SaveConfig:
     def __init__(self, **kwargs):
@@ -31,6 +32,7 @@
         self.sample_steps = kwargs.get('sample_steps', 20)
         self.network_multiplier = kwargs.get('network_multiplier', 1)
         self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+        self.ext: ImgExt = kwargs.get('format', 'jpg')
 
 
 class NetworkConfig:
@@ -158,7 +160,10 @@ class SliderConfig:
 
 
 class DatasetConfig:
-    caption_type: Literal["txt", "caption"] = 'txt'
+    """
+    Dataset config for sd-datasets
+
+    """
 
     def __init__(self, **kwargs):
         self.type = kwargs.get('type', 'image')  # sd, slider, reference
@@ -172,6 +177,10 @@
         self.buckets: bool = kwargs.get('buckets', False)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
         self.is_reg: bool = kwargs.get('is_reg', False)
+        self.network_weight: float = float(kwargs.get('network_weight', 1.0))
+        self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
+        self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
+        self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
 
 
 class GenerateImageConfig:
@@ -191,7 +200,7 @@
             # the tag [time] will be replaced with milliseconds since epoch
             output_path: str = None,  # full image path
             output_folder: str = None,  # folder to save image in if output_path is not specified
-            output_ext: str = 'png',  # extension to save image as if output_path is not specified
+            output_ext: ImgExt = 'png',  # extension to save image as if output_path is not specified
             output_tail: str = '',  # tail to add to output filename
             add_prompt_file: bool = False,  # add a prompt file with generated image
     ):
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index cc50c0d9..5eb1c371 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -15,6 +15,8 @@ import albumentations as A
 from toolkit import image_utils
 from toolkit.config_modules import DatasetConfig
 from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
+from toolkit.data_transfer_object.data_loader import FileItemDTO
+
 
 
 class ImageDataset(Dataset, CaptionMixin):
@@ -296,20 +298,6 @@ def print_once(msg):
         printed_messages.append(msg)
 
 
-class FileItem:
-    def __init__(self, **kwargs):
-        self.path = kwargs.get('path', None)
-        self.width = kwargs.get('width', None)
-        self.height = kwargs.get('height', None)
-        # we scale first, then crop
-        self.scale_to_width = kwargs.get('scale_to_width', self.width)
-        self.scale_to_height = kwargs.get('scale_to_height', self.height)
-        # crop values are from scaled size
-        self.crop_x = kwargs.get('crop_x', 0)
-        self.crop_y = kwargs.get('crop_y', 0)
-        self.crop_width = kwargs.get('crop_width', self.scale_to_width)
-        self.crop_height = kwargs.get('crop_height', self.scale_to_height)
-
 
 class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
@@ -325,7 +313,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
         # we always random crop if random scale is enabled
         self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
         self.resolution = dataset_config.resolution
-        self.file_list: List['FileItem'] = []
+        self.file_list: List['FileItemDTO'] = []
 
         # get the file list
         file_list = [
@@ -344,14 +332,16 @@
                           f'This process is faster for png, jpeg')
             img = Image.open(file)
             h, w = img.size
+            # TODO allow smaller images
             if int(min(h, w) * self.scale) >= self.resolution:
                 self.file_list.append(
-                    FileItem(
+                    FileItemDTO(
                         path=file,
                         width=w,
                         height=h,
                         scale_to_width=int(w * self.scale),
                         scale_to_height=int(h * self.scale),
+                        dataset_config=dataset_config
                     )
                 )
             else:
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
new file mode 100644
index 00000000..f2c9a509
--- /dev/null
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -0,0 +1,36 @@
+from typing import TYPE_CHECKING
+import torch
+import random
+
+from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
+
+if TYPE_CHECKING:
+    from toolkit.config_modules import DatasetConfig
+
+
+class FileItemDTO(CaptionProcessingDTOMixin):
+    def __init__(self, **kwargs):
+        self.path = kwargs.get('path', None)
+        self.caption_path: str = kwargs.get('caption_path', None)
+        self.raw_caption: str = kwargs.get('raw_caption', None)
+        self.width: int = kwargs.get('width', None)
+        self.height: int = kwargs.get('height', None)
+        # we scale first, then crop
+        self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
+        self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
+        # crop values are from scaled size
+        self.crop_x: int = kwargs.get('crop_x', 0)
+        self.crop_y: int = kwargs.get('crop_y', 0)
+        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
+        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
+
+        # process config
+        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+
+        self.network_weight: float = self.dataset_config.network_weight
+
+
+class DataLoaderBatchDTO:
+    def __init__(self, **kwargs):
+        self.file_item: 'FileItemDTO' = kwargs.get('file_item', None)
+        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index fea29421..bc4be6a7 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -1,9 +1,19 @@
 import os
+import random
 from typing import TYPE_CHECKING, List, Dict
 
+from toolkit.prompt_utils import inject_trigger_into_prompt
+
+if TYPE_CHECKING:
+    from toolkit.data_loader import AiToolkitDataset
+    from toolkit.data_transfer_object.data_loader import FileItemDTO
+
+
+# def get_associated_caption_from_img_path(img_path):
+
 
 class CaptionMixin:
-    def get_caption_item(self, index):
+    def get_caption_item(self: 'AiToolkitDataset', index):
         if not hasattr(self, 'caption_type'):
             raise Exception('caption_type not found on class instance')
         if not hasattr(self, 'file_list'):
@@ -48,7 +58,7 @@
 
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
-    from toolkit.data_loader import FileItem
+    from toolkit.data_transfer_object.data_loader import FileItemDTO
 
 
 class Bucket:
@@ -63,14 +73,14 @@
         self.buckets: Dict[str, Bucket] = {}
         self.batch_indices: List[List[int]] = []
 
-    def build_batch_indices(self):
+    def build_batch_indices(self: 'AiToolkitDataset'):
         for key, bucket in self.buckets.items():
             for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
                 end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
                 batch = bucket.file_list_idx[start_idx:end_idx]
                 self.batch_indices.append(batch)
 
-    def setup_buckets(self):
+    def setup_buckets(self: 'AiToolkitDataset'):
         if not hasattr(self, 'file_list'):
             raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
         if not hasattr(self, 'dataset_config'):
@@ -79,7 +89,7 @@
         config: 'DatasetConfig' = self.dataset_config
         resolution = config.resolution
         bucket_tolerance = config.bucket_tolerance
-        file_list: List['FileItem'] = self.file_list
+        file_list: List['FileItemDTO'] = self.file_list
 
         # make sure out resolution is divisible by bucket_tolerance
         if resolution % bucket_tolerance != 0:
@@ -146,3 +156,48 @@
         print(f'{len(self.buckets)} buckets made')
 
         # file buckets made
+
+
+class CaptionProcessingDTOMixin:
+    def get_caption(
+            self: 'FileItemDTO',
+            trigger=None,
+            to_replace_list=None,
+            add_if_not_present=True
+    ):
+        raw_caption = self.raw_caption
+        if raw_caption is None:
+            raw_caption = ''
+        # handle dropout
+        if self.dataset_config.caption_dropout_rate > 0:
+            # get a random float from 0 to 1
+            rand = random.random()
+            if rand < self.dataset_config.caption_dropout_rate:
+                # drop the caption
+                return ''
+
+        # get tokens
+        token_list = raw_caption.split(',')
+        # trim whitespace
+        token_list = [x.strip() for x in token_list]
+        # remove empty strings
+        token_list = [x for x in token_list if x]
+
+        if self.dataset_config.shuffle_tokens:
+            random.shuffle(token_list)
+
+        # handle token dropout
+        if self.dataset_config.token_dropout_rate > 0:
+            new_token_list = []
+            for token in token_list:
+                # get a random float from 0 to 1
+                rand = random.random()
+                if rand > self.dataset_config.token_dropout_rate:
+                    # keep the token
+                    new_token_list.append(token)
+            token_list = new_token_list
+
+        # join back together
+        caption = ', '.join(token_list)
+        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
+        return caption
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index 813c3bad..3ba3e257 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -1,12 +1,11 @@
 import os
-from typing import Optional, TYPE_CHECKING, List
+from typing import Optional, TYPE_CHECKING, List, Union, Tuple
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 import random
 
-from toolkit.stable_diffusion_model import PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
 import itertools
 
 
@@ -19,6 +18,27 @@ class ACTION_TYPES_SLIDER:
     ENHANCE_NEGATIVE = 1
 
 
+class PromptEmbeds:
+    text_embeds: torch.Tensor
+    pooled_embeds: Union[torch.Tensor, None]
+
+    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
+        if isinstance(args, list) or isinstance(args, tuple):
+            # xl
+            self.text_embeds = args[0]
+            self.pooled_embeds = args[1]
+        else:
+            # sdv1.x, sdv2.x
+            self.text_embeds = args
+            self.pooled_embeds = None
+
+    def to(self, *args, **kwargs):
+        self.text_embeds = self.text_embeds.to(*args, **kwargs)
+        if self.pooled_embeds is not None:
+            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
+        return self
+
+
 class EncodedPromptPair:
     def __init__(
             self,
@@ -465,3 +485,37 @@ def build_latent_image_batch_for_prompt_pair(
         latent_list.append(neg_latent)
 
     return torch.cat(latent_list, dim=0)
+
+
+def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
+    if trigger is None:
+        return prompt
+    output_prompt = prompt
+    default_replacements = ["[name]", "[trigger]"]
+
+    replace_with = trigger
+    if to_replace_list is None:
+        to_replace_list = default_replacements
+    else:
+        to_replace_list += default_replacements
+
+    # remove duplicates
+    to_replace_list = list(set(to_replace_list))
+
+    # replace them all
+    for to_replace in to_replace_list:
+        # replace it
+        output_prompt = output_prompt.replace(to_replace, replace_with)
+
+    # see how many times replace_with is in the prompt
+    num_instances = output_prompt.count(replace_with)
+
+    if num_instances == 0 and add_if_not_present:
+        # add it to the beginning of the prompt
+        output_prompt = replace_with + " " + output_prompt
+
+    if num_instances > 1:
+        print(
+            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+    return output_prompt
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index b831cdd9..20e19312 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -16,6 +16,7 @@ from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
+from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
 from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
@@ -63,25 +64,7 @@ UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion (XL too)
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
 
-class PromptEmbeds:
-    text_embeds: torch.Tensor
-    pooled_embeds: Union[torch.Tensor, None]
-    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
-        if isinstance(args, list) or isinstance(args, tuple):
-            # xl
-            self.text_embeds = args[0]
-            self.pooled_embeds = args[1]
-        else:
-            # sdv1.x, sdv2.x
-            self.text_embeds = args
-            self.pooled_embeds = None
-
-    def to(self, *args, **kwargs):
-        self.text_embeds = self.text_embeds.to(*args, **kwargs)
-        if self.pooled_embeds is not None:
-            self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
-        return self
 
 
 
 # if is type checking
@@ -708,38 +691,12 @@ class StableDiffusion:
             raise ValueError(f"Unknown weight name: {name}")
 
     def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
-        if trigger is None:
-            return prompt
-        output_prompt = prompt
-        default_replacements = ["[name]", "[trigger]"]
-        num_times_trigger_exists = prompt.count(trigger)
-
-        replace_with = trigger
-        if to_replace_list is None:
-            to_replace_list = default_replacements
-        else:
-            to_replace_list += default_replacements
-
-        # remove duplicates
-        to_replace_list = list(set(to_replace_list))
-
-        # replace them all
-        for to_replace in to_replace_list:
-            # replace it
-            output_prompt = output_prompt.replace(to_replace, replace_with)
-
-        # see how many times replace_with is in the prompt
-        num_instances = output_prompt.count(replace_with)
-
-        if num_instances == 0 and add_if_not_present:
-            # add it to the beginning of the prompt
-            output_prompt = replace_with + " " + output_prompt
-
-        if num_instances > 1:
-            print(
-                f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
-
-        return output_prompt
+        return inject_trigger_into_prompt(
+            prompt,
+            trigger=trigger,
+            to_replace_list=to_replace_list,
+            add_if_not_present=add_if_not_present,
+        )
 
     def state_dict(self, vae=True, text_encoder=True, unet=True):
         state_dict = OrderedDict()
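
Below are a few usage sketches against the APIs exactly as they appear in the diff. The new `DatasetConfig` keys (`caption_dropout_rate`, `token_dropout_rate`, `shuffle_tokens`) drive `CaptionProcessingDTOMixin.get_caption`, which splits the raw caption on commas, optionally shuffles and drops tokens, then re-joins the tokens and injects the trigger word. A minimal sketch of that path; the file path, dimensions, caption, and rates below are made up for illustration:

```python
from toolkit.config_modules import DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO

# hypothetical config values for illustration
config = DatasetConfig(
    caption_dropout_rate=0.05,  # drop the whole caption ~5% of the time
    token_dropout_rate=0.10,    # drop each comma-separated token ~10% of the time
    shuffle_tokens=True,        # shuffle token order before dropout
)

item = FileItemDTO(
    path='/data/images/img_001.jpg',  # hypothetical path
    width=1024,
    height=768,
    raw_caption='a photo of [trigger], outdoors, golden hour',
    dataset_config=config,
)

# output varies run to run because shuffle and dropout are random
print(item.get_caption(trigger='sks person'))
# e.g. 'golden hour, a photo of sks person, outdoors'
```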
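
`inject_trigger_into_prompt` moves to `toolkit/prompt_utils.py` so the data loader and `StableDiffusion` share one implementation (the old method body becomes a thin delegate). Its behavior, read straight from the diff:

```python
from toolkit.prompt_utils import inject_trigger_into_prompt

# [name] / [trigger] placeholders are replaced with the trigger word
inject_trigger_into_prompt('a photo of [trigger] at the beach', trigger='sks person')
# -> 'a photo of sks person at the beach'

# no placeholder and add_if_not_present=True (the default): trigger is prepended
inject_trigger_into_prompt('a photo at the beach', trigger='sks person')
# -> 'sks person a photo at the beach'

# trigger=None: the prompt passes through unchanged
inject_trigger_into_prompt('a photo at the beach')
# -> 'a photo at the beach'
```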
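
`PromptEmbeds` also moves into `toolkit/prompt_utils.py`; since `stable_diffusion_model.py` now imports from `prompt_utils`, keeping the class where it was would have made the two modules import each other. A small sketch of the class as defined in the diff; the tensor shapes are illustrative only:

```python
import torch

from toolkit.prompt_utils import PromptEmbeds

# SDXL-style input: a (text_embeds, pooled_embeds) pair
xl = PromptEmbeds((torch.randn(1, 77, 2048), torch.randn(1, 1280)))

# SD 1.x / 2.x-style input: a single tensor; pooled_embeds stays None
sd = PromptEmbeds(torch.randn(1, 77, 768))

# .to() moves/casts both tensors and returns self for chaining
xl = xl.to('cpu', dtype=torch.float16)
```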