diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ca608b6e..3f7ecbc1 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
 from toolkit.basic import value_map
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.ip_adapter import IPAdapter
+from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
 from toolkit.train_tools import get_torch_dtype, apply_snr_weight
 import gc
@@ -27,6 +28,10 @@ class SDTrainer(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
         super().__init__(process_id, job, config, **kwargs)
         self.assistant_adapter: Union['T2IAdapter', None]
+        self.do_prior_prediction = False
+        if self.train_config.inverted_mask_prior:
+            self.do_prior_prediction = True
+
     def before_model_load(self):
         pass
 
@@ -135,6 +140,40 @@ class SDTrainer(BaseSDTrainProcess):
     def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
         return batch
 
+    def get_prior_prediction(
+            self,
+            noisy_latents: torch.Tensor,
+            conditional_embeds: PromptEmbeds,
+            match_adapter_assist: bool,
+            network_weight_list: list,
+            timesteps: torch.Tensor,
+            pred_kwargs: dict,
+            batch: 'DataLoaderBatchDTO',
+            noise: torch.Tensor,
+            **kwargs
+    ):
+        # do a prediction here so we can match its output with network multiplier set to 0.0
+        with torch.no_grad():
+            dtype = get_torch_dtype(self.train_config.dtype)
+            # dont use network on this
+            self.network.multiplier = 0.0
+            self.sd.unet.eval()
+            prior_pred = self.sd.predict_noise(
+                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
+                timestep=timesteps,
+                guidance_scale=1.0,
+                **pred_kwargs  # adapter residuals in here
+            )
+            self.sd.unet.train()
+            prior_pred = prior_pred.detach()
+            # remove the residuals as we wont use them on prediction when matching control
+            if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
+                del pred_kwargs['down_block_additional_residuals']
+            # restore network
+            self.network.multiplier = network_weight_list
+        return prior_pred
+
     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
         self.timer.start('preprocess_batch')
 
@@ -287,28 +326,18 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
 
            prior_pred = None
-            if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior:
+            if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
                with self.timer('prior predict'):
-                    # do a prediction here so we can match its output with network multiplier set to 0.0
-                    with torch.no_grad():
-                        # dont use network on this
-                        network.multiplier = 0.0
-                        self.sd.unet.eval()
-                        prior_pred = self.sd.predict_noise(
-                            latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
-                            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
-                            timestep=timesteps,
-                            guidance_scale=1.0,
-                            **pred_kwargs  # adapter residuals in here
-                        )
-                        self.sd.unet.train()
-                        prior_pred = prior_pred.detach()
-                        # remove the residuals as we wont use them on prediction when matching control
-                        if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
-                            del pred_kwargs['down_block_additional_residuals']
-                        # restore network
-                        network.multiplier = network_weight_list
-
+                    prior_pred = self.get_prior_prediction(
+                        noisy_latents=noisy_latents,
+                        conditional_embeds=conditional_embeds,
+                        match_adapter_assist=match_adapter_assist,
+                        network_weight_list=network_weight_list,
+                        timesteps=timesteps,
+                        pred_kwargs=pred_kwargs,
+                        noise=noise,
+                        batch=batch,
+                    )
 
            if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
                with self.timer('encode_adapter'):
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 614579a6..9ea062b5 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -55,6 +55,18 @@ transforms_dict = {
 
 caption_ext_list = ['txt', 'json', 'caption']
 
 
+def clean_caption(caption):
+    # remove any newlines
+    caption = caption.replace('\n', ', ')
+    # remove new lines for all operating systems
+    caption = caption.replace('\r', ', ')
+    caption_split = caption.split(',')
+    # remove empty strings
+    caption_split = [p.strip() for p in caption_split if p.strip()]
+    # join back together
+    caption = ', '.join(caption_split)
+    return caption
+
 class CaptionMixin:
     def get_caption_item(self: 'AiToolkitDataset', index):
@@ -91,15 +103,7 @@ class CaptionMixin:
                    if 'caption' in prompt:
                        prompt = prompt['caption']
 
-            # remove any newlines
-            prompt = prompt.replace('\n', ', ')
-            # remove new lines for all operating systems
-            prompt = prompt.replace('\r', ', ')
-            prompt_split = prompt.split(',')
-            # remove empty strings
-            prompt_split = [p.strip() for p in prompt_split if p.strip()]
-            # join back together
-            prompt = ', '.join(prompt_split)
+            prompt = clean_caption(prompt)
        else:
            prompt = ''
            # get default_prompt if it exists on the class instance
@@ -135,6 +139,10 @@ class BucketsMixin:
                batch = bucket.file_list_idx[start_idx:end_idx]
                self.batch_indices.append(batch)
 
+    def shuffle_buckets(self: 'AiToolkitDataset'):
+        for key, bucket in self.buckets.items():
+            random.shuffle(bucket.file_list_idx)
+
     def setup_buckets(self: 'AiToolkitDataset', quiet=False):
        if not hasattr(self, 'file_list'):
            raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
@@ -206,6 +214,7 @@ class BucketsMixin:
                self.buckets[bucket_key].file_list_idx.append(idx)
 
        # print the buckets
+        self.shuffle_buckets()
        self.build_batch_indices()
        if not quiet:
            print(f'Bucket sizes for {self.dataset_path}:')
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index fd2e6dfd..dd17406d 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -327,7 +327,6 @@ class StableDiffusion:
                    scheduler=noise_scheduler,
                    **extra_args
                ).to(self.device_torch)
-                # force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
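+                # keep the watermarker disabled; the default watermark adds obvious red/green dot artifacts to images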
                pipeline.watermark = None
            else:
                pipeline = Pipe(
@@ -372,7 +371,8 @@ class StableDiffusion:
                    extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
                if isinstance(self.adapter, IPAdapter):
                    transform = transforms.Compose([
-                        transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
+                        transforms.Resize(gen_config.width,
+                                          interpolation=transforms.InterpolationMode.BILINEAR),
                        transforms.PILToTensor(),
                    ])
                    validation_image = transform(validation_image)
@@ -395,14 +395,15 @@ class StableDiffusion:
                    unconditional_embeds,
                )
 
-                if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
+                if self.adapter is not None and isinstance(self.adapter,
+                                                           IPAdapter) and gen_config.adapter_image_path is not None:
                    # apply the image projection
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
-                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
+                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
+                                                                                                 True)
                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                    unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
 
-                # todo do we disable text encoder here as well if disabled for model, or only do that for training?
                if self.is_xl:
                    # fix guidance rescale for sdxl
@@ -668,7 +669,15 @@ class StableDiffusion:
            # return latents_steps
            return latents
 
-    def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
+    def encode_prompt(
+            self,
+            prompt,
+            prompt2=None,
+            num_images_per_prompt=1,
+            force_all=False,
+            long_prompts=False,
+            max_length=None
+    ) -> PromptEmbeds:
        # sd1.5 embeddings are (bs, 77, 768)
        prompt = prompt
        # if it is not a list, make it one
@@ -695,12 +704,14 @@ class StableDiffusion:
                    num_images_per_prompt=num_images_per_prompt,
                    use_text_encoder_1=use_encoder_1,
                    use_text_encoder_2=use_encoder_2,
+                    truncate=not long_prompts,
+                    max_length=max_length,
                )
            )
        else:
            return PromptEmbeds(
                train_tools.encode_prompts(
-                    self.tokenizer, self.text_encoder, prompt
+                    self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
                )
            )
 
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 1e69d798..99cf85f9 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -447,29 +447,78 @@ if TYPE_CHECKING:
 
 
 def text_tokenize(
-        tokenizer: 'CLIPTokenizer',  # 普通ならひとつ、XLならふたつ!
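+        # normally a single tokenizer; SDXL encodes with two, and this gets called once per tokenizer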
+        tokenizer: 'CLIPTokenizer',
        prompts: list[str],
+        truncate: bool = True,
+        max_length: int = None,
+        max_length_multiplier: int = 4,
 ):
-    return tokenizer(
+    # allow for up to 4x the max length for long prompts
+    if max_length is None:
+        if truncate:
+            max_length = tokenizer.model_max_length
+        else:
+            # allow up to 4x the max length for long prompts
+            max_length = tokenizer.model_max_length * max_length_multiplier
+
+    input_ids = tokenizer(
        prompts,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
+        padding='max_length',
+        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
+    if truncate or max_length == tokenizer.model_max_length:
+        return input_ids
+    else:
+        # remove additional padding
+        num_chunks = input_ids.shape[1] // tokenizer.model_max_length
+        chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
+
+        # New list to store non-redundant chunks
+        non_redundant_chunks = []
+
+        for chunk in chunks:
+            if not chunk.eq(chunk[0, 0]).all():  # skip chunks where every token equals the first one (pure padding)
+                non_redundant_chunks.append(chunk)
+
+        input_ids = torch.cat(non_redundant_chunks, dim=1)
+        return input_ids
+
 
 # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
 def text_encode_xl(
        text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
        tokens: torch.FloatTensor,
        num_images_per_prompt: int = 1,
+        max_length: int = 77,  # not sure what default to put here, always pass one?
+        truncate: bool = True,
 ):
-    prompt_embeds = text_encoder(
-        tokens.to(text_encoder.device), output_hidden_states=True
-    )
-    pooled_prompt_embeds = prompt_embeds[0]
-    prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+    if truncate:
+        # normal short prompt, 77 tokens max
+        prompt_embeds = text_encoder(
+            tokens.to(text_encoder.device), output_hidden_states=True
+        )
+        pooled_prompt_embeds = prompt_embeds[0]
+        prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+    else:
+        # handle long prompts
+        prompt_embeds_list = []
+        tokens = tokens.to(text_encoder.device)
+        pooled_prompt_embeds = None
+        for i in range(0, tokens.shape[-1], max_length):
+            # todo: run all chunks through the encoder in a single batch
+            section_tokens = tokens[:, i: i + max_length]
+            embeds = text_encoder(section_tokens, output_hidden_states=True)
+            pooled_prompt_embed = embeds[0]
+            if pooled_prompt_embeds is None:
+                # we only want the pooled output from the first chunk (I think??)
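+                # (the pooled output is a single vector per prompt, so we take it from the first chunk only)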
+                pooled_prompt_embeds = pooled_prompt_embed
+            prompt_embed = embeds.hidden_states[-2]  # always penultimate layer
+            prompt_embeds_list.append(prompt_embed)
+
+        prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
 
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -485,7 +534,9 @@ def encode_prompts_xl(
        prompts2: Union[list[str], None],
        num_images_per_prompt: int = 1,
        use_text_encoder_1: bool = True,  # sdxl
-        use_text_encoder_2: bool = True  # sdxl
+        use_text_encoder_2: bool = True,  # sdxl
+        truncate: bool = True,
+        max_length=None,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    # text_encoder and text_encoder_2's penuultimate layer's output
    text_embeds_list = []
@@ -502,9 +553,14 @@ def encode_prompts_xl(
        if idx == 1 and not use_text_encoder_2:
            prompt_list_to_use = ["" for _ in prompts]
 
-        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
+        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
+        # set the max length for the next encoder so both token sequences end up the same length
+        if idx == 0:
+            max_length = text_tokens_input_ids.shape[-1]
+
        text_embeds, pooled_text_embeds = text_encode_xl(
-            text_encoder, text_tokens_input_ids, num_images_per_prompt
+            text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
+            truncate=truncate
        )
 
        text_embeds_list.append(text_embeds)
@@ -517,18 +573,36 @@ def encode_prompts_xl(
    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
 
 
-def text_encode(text_encoder: 'CLIPTextModel', tokens):
-    return text_encoder(tokens.to(text_encoder.device))[0]
+# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
+def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
+    if max_length is None and not truncate:
+        raise ValueError("max_length must be set if truncate is False")
+    tokens = tokens.to(text_encoder.device)
+
+    if truncate:
+        return text_encoder(tokens)[0]
+    else:
+        # handle long prompts
+        prompt_embeds_list = []
+        for i in range(0, tokens.shape[-1], max_length):
+            prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
+            prompt_embeds_list.append(prompt_embeds)
+
+        return torch.cat(prompt_embeds_list, dim=1)
 
 
 def encode_prompts(
        tokenizer: 'CLIPTokenizer',
-        text_encoder: 'CLIPTokenizer',
+        text_encoder: 'CLIPTextModel',
        prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
 ):
-    text_tokens = text_tokenize(tokenizer, prompts)
-    text_embeddings = text_encode(text_encoder, text_tokens)
+    if max_length is None:
+        max_length = tokenizer.model_max_length
+    text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
+    text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
 
    return text_embeddings
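A minimal sketch of how the new long-prompt path might be called, assuming an already-initialized StableDiffusion instance named `sd` (the instance name is illustrative; the parameter names come from the signatures added above):

    # default path: tokenize with truncation to tokenizer.model_max_length (77 tokens for CLIP)
    short_embeds = sd.encode_prompt("a photo of a red fox in the snow")

    # long-prompt path: tokens are padded up to 4x the window (max_length_multiplier), pure-padding
    # chunks are dropped, and each 77-token chunk is encoded separately, with the per-chunk
    # embeddings concatenated along the sequence dimension
    long_prompt = ", ".join(["a detailed photo of a red fox standing in fresh snow"] * 20)
    long_embeds = sd.encode_prompt(long_prompt, long_prompts=True)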