mirror of https://github.com/ostris/ai-toolkit.git
Fixes and longer prompts
@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
@@ -27,6 +28,10 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True


def before_model_load(self):
pass
@@ -135,6 +140,40 @@ class SDTrainer(BaseSDTrainProcess):
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch

def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# don't use the network for this
self.network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we won't use them in the prediction when matching the control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
self.network.multiplier = network_weight_list
return prior_pred
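
The pattern above, in isolation: a minimal, hedged sketch of a "prior prediction" step, assuming a hypothetical network object with a multiplier attribute and a predict callable standing in for self.sd.predict_noise (not the toolkit's exact API).

import torch

def prior_prediction(predict, network, noisy_latents, embeds, timesteps, restore_multiplier):
    # Temporarily disable the trained network so the base model's noise
    # prediction can be captured as a reference target.
    with torch.no_grad():
        network.multiplier = 0.0
        prior = predict(noisy_latents, embeds, timesteps).detach()
    # Restore the network for the actual training step.
    network.multiplier = restore_multiplier
    return prior
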
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):

self.timer.start('preprocess_batch')
@@ -287,28 +326,18 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals

prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior:
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# don't use the network for this
network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we won't use them in the prediction when matching the control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list

prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
)
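
The do_prior_prediction flag ties this prior prediction to inverted_mask_prior training. As a rough, hypothetical sketch only (the repo's actual loss code is not shown in this diff), one way a prior prediction can be combined with an inverted mask is to let the real target supervise the masked subject while the prior prediction supervises everything outside the mask:

import torch
import torch.nn.functional as F

def masked_prior_loss(pred, target, prior_pred, mask):
    # mask == 1 over the subject; the inverted mask covers the background.
    inv_mask = 1.0 - mask
    loss_subject = F.mse_loss(pred * mask, target * mask)
    loss_prior = F.mse_loss(pred * inv_mask, prior_pred * inv_mask)
    return loss_subject + loss_prior
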

if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):

@@ -55,6 +55,18 @@ transforms_dict = {

caption_ext_list = ['txt', 'json', 'caption']

def clean_caption(caption):
# remove any newlines
caption = caption.replace('\n', ', ')
# remove new lines for all operating systems
caption = caption.replace('\r', ', ')
caption_split = caption.split(',')
# remove empty strings
caption_split = [p.strip() for p in caption_split if p.strip()]
# join back together
caption = ', '.join(caption_split)
return caption
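
A quick usage check for clean_caption with an illustrative input:

caption = clean_caption("a photo of a dog,\nrunning in a park,, ")
print(caption)  # 'a photo of a dog, running in a park'
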


class CaptionMixin:
def get_caption_item(self: 'AiToolkitDataset', index):
@@ -91,15 +103,7 @@ class CaptionMixin:
if 'caption' in prompt:
prompt = prompt['caption']

# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
prompt = clean_caption(prompt)
else:
prompt = ''
# get default_prompt if it exists on the class instance
@@ -135,6 +139,10 @@ class BucketsMixin:
batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch)

def shuffle_buckets(self: 'AiToolkitDataset'):
for key, bucket in self.buckets.items():
random.shuffle(bucket.file_list_idx)

def setup_buckets(self: 'AiToolkitDataset', quiet=False):
if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
@@ -206,6 +214,7 @@ class BucketsMixin:
self.buckets[bucket_key].file_list_idx.append(idx)

# print the buckets
self.shuffle_buckets()
self.build_batch_indices()
if not quiet:
print(f'Bucket sizes for {self.dataset_path}:')
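
For orientation, a minimal stand-alone sketch of what the bucket bookkeeping amounts to; the Bucket class and bucket keys here are simplified stand-ins, not the dataset code itself:

import random

class Bucket:
    def __init__(self, file_list_idx):
        self.file_list_idx = file_list_idx

# Each bucket groups images of one resolution/aspect ratio so batches stay shape-consistent.
buckets = {
    '512x512': Bucket([0, 2, 5, 7]),
    '640x384': Bucket([1, 3, 4, 6]),
}

batch_size = 2
batch_indices = []
for bucket in buckets.values():
    random.shuffle(bucket.file_list_idx)  # shuffle within the bucket only
    for start in range(0, len(bucket.file_list_idx), batch_size):
        batch_indices.append(bucket.file_list_idx[start:start + batch_size])
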

@@ -327,7 +327,6 @@ class StableDiffusion:
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
# force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
pipeline.watermark = None
else:
pipeline = Pipe(
@@ -372,7 +371,8 @@ class StableDiffusion:
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(gen_config.width,
interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
])
validation_image = transform(validation_image)
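
A small self-contained example of what that validation-image transform produces, with a plain integer standing in for gen_config.width:

from PIL import Image
from torchvision import transforms

width = 512  # stand-in for gen_config.width
transform = transforms.Compose([
    transforms.Resize(width, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.PILToTensor(),
])
image = Image.new('RGB', (768, 512))
tensor = transform(image)  # uint8 tensor of shape (3, 512, 768)
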
@@ -395,14 +395,15 @@ class StableDiffusion:
unconditional_embeds,
)

if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
if self.adapter is not None and isinstance(self.adapter,
IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)


# todo do we disable the text encoder here as well if it is disabled for the model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -668,7 +669,15 @@ class StableDiffusion:
# return latents_steps
return latents

def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
def encode_prompt(
self,
prompt,
prompt2=None,
num_images_per_prompt=1,
force_all=False,
long_prompts=False,
max_length=None
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
# if it is not a list, make it one
@@ -695,12 +704,14 @@ class StableDiffusion:
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=use_encoder_1,
use_text_encoder_2=use_encoder_2,
truncate=not long_prompts,
max_length=max_length,
)
)
else:
return PromptEmbeds(
train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt
self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
)
)
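
A hedged usage sketch of the extended encode_prompt signature, assuming sd is an already-loaded StableDiffusion instance:

# Default behaviour truncates at the tokenizer limit (77 tokens for CLIP).
short_embeds = sd.encode_prompt("a lighthouse at dusk")

# With long_prompts=True, truncation is turned off and the chunked long-prompt
# handling from train_tools (defined further down in this commit) is used.
long_prompt = ", ".join(["a highly detailed oil painting of a lighthouse at dusk"] * 10)
long_embeds = sd.encode_prompt(long_prompt, long_prompts=True)
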


@@ -447,29 +447,78 @@ if TYPE_CHECKING:


def text_tokenize(
tokenizer: 'CLIPTokenizer', # one for regular models, two for XL!
tokenizer: 'CLIPTokenizer',
prompts: list[str],
truncate: bool = True,
max_length: int = None,
max_length_multiplier: int = 4,
):
return tokenizer(
# allow for up to 4x the max length for long prompts
if max_length is None:
if truncate:
max_length = tokenizer.model_max_length
else:
# allow up to 4x the max length for long prompts
max_length = tokenizer.model_max_length * max_length_multiplier

input_ids = tokenizer(
prompts,
padding="max_length",
max_length=tokenizer.model_max_length,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors="pt",
).input_ids

if truncate or max_length == tokenizer.model_max_length:
return input_ids
else:
# remove additional padding
num_chunks = input_ids.shape[1] // tokenizer.model_max_length
chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)

# New list to store non-redundant chunks
non_redundant_chunks = []

for chunk in chunks:
if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element
non_redundant_chunks.append(chunk)

input_ids = torch.cat(non_redundant_chunks, dim=1)
return input_ids
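
A hedged example of the long-prompt tokenization path, assuming the transformers CLIP tokenizer (model_max_length of 77); chunks that contain nothing but padding are dropped, so the returned width is the smallest multiple of 77 that covers the prompt:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
long_prompt = " ".join(["photo of a castle at sunset,"] * 25)  # well over 77 tokens

ids_long = text_tokenize(tokenizer, [long_prompt], truncate=False)
print(ids_long.shape)   # (1, n*77) with the all-padding chunks removed
ids_short = text_tokenize(tokenizer, [long_prompt], truncate=True)
print(ids_short.shape)  # (1, 77), hard truncation
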


# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
tokens: torch.FloatTensor,
num_images_per_prompt: int = 1,
max_length: int = 77, # not sure what default to put here, always pass one?
truncate: bool = True,
):
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
if truncate:
# normal short prompt, 77 tokens max
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
else:
# handle long prompts
prompt_embeds_list = []
tokens = tokens.to(text_encoder.device)
pooled_prompt_embeds = None
for i in range(0, tokens.shape[-1], max_length):
# todo: run it through the encoder in a single batch
section_tokens = tokens[:, i: i + max_length]
embeds = text_encoder(section_tokens, output_hidden_states=True)
pooled_prompt_embed = embeds[0]
if pooled_prompt_embeds is None:
# we only want the first (I think??)
pooled_prompt_embeds = pooled_prompt_embed
prompt_embed = embeds.hidden_states[-2] # always penultimate layer
prompt_embeds_list.append(prompt_embed)

prompt_embeds = torch.cat(prompt_embeds_list, dim=1)

bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
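
To make the long-prompt branch concrete, a runnable mock of the chunked encoding loop; dummy_encoder stands in for the CLIP text encoder and the 768-dim hidden size is purely illustrative (only the first window's pooled vector is kept):

import torch

def dummy_encoder(section_tokens):
    bs, n = section_tokens.shape
    return torch.zeros(bs, n, 768), torch.zeros(bs, 768)  # (hidden states, pooled)

tokens = torch.zeros(1, 231, dtype=torch.long)  # three 77-token windows
hidden_list, pooled = [], None
for i in range(0, tokens.shape[-1], 77):
    hidden, pooled_i = dummy_encoder(tokens[:, i:i + 77])
    if pooled is None:
        pooled = pooled_i  # keep only the first window's pooled embedding
    hidden_list.append(hidden)

prompt_embeds = torch.cat(hidden_list, dim=1)
print(prompt_embeds.shape)  # torch.Size([1, 231, 768])
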
@@ -485,7 +534,9 @@ def encode_prompts_xl(
prompts2: Union[list[str], None],
num_images_per_prompt: int = 1,
use_text_encoder_1: bool = True, # sdxl
use_text_encoder_2: bool = True # sdxl
use_text_encoder_2: bool = True, # sdxl
truncate: bool = True,
max_length=None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# text_encoder and text_encoder_2's penultimate layer's output
text_embeds_list = []
@@ -502,9 +553,14 @@ def encode_prompts_xl(
if idx == 1 and not use_text_encoder_2:
prompt_list_to_use = ["" for _ in prompts]

text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
# set the max length for the next one so both encoders produce the same number of tokens
if idx == 0:
max_length = text_tokens_input_ids.shape[-1]

text_embeds, pooled_text_embeds = text_encode_xl(
text_encoder, text_tokens_input_ids, num_images_per_prompt
text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
truncate=truncate
)

text_embeds_list.append(text_embeds)
@@ -517,18 +573,36 @@ def encode_prompts_xl(
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds


def text_encode(text_encoder: 'CLIPTextModel', tokens):
return text_encoder(tokens.to(text_encoder.device))[0]
# ref for long prompts: https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
if max_length is None and not truncate:
raise ValueError("max_length must be set if truncate is False")

tokens = tokens.to(text_encoder.device)

if truncate:
return text_encoder(tokens)[0]
else:
# handle long prompts
prompt_embeds_list = []
for i in range(0, tokens.shape[-1], max_length):
prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
prompt_embeds_list.append(prompt_embeds)

return torch.cat(prompt_embeds_list, dim=1)
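
A brief note on why concatenating chunk embeddings is workable at all: the UNet consumes the text embeddings through cross-attention, and attention keys and values can have any sequence length, so a (bs, 154, 768) conditioning tensor is as valid as the usual (bs, 77, 768). A tiny shape check with random stand-ins for real CLIP outputs:

import torch

chunk_a = torch.randn(1, 77, 768)
chunk_b = torch.randn(1, 77, 768)
long_embedding = torch.cat([chunk_a, chunk_b], dim=1)
print(long_embedding.shape)  # torch.Size([1, 154, 768])
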


def encode_prompts(
tokenizer: 'CLIPTokenizer',
text_encoder: 'CLIPTokenizer',
text_encoder: 'CLIPTextModel',
prompts: list[str],
truncate: bool = True,
max_length=None,
):
text_tokens = text_tokenize(tokenizer, prompts)
text_embeddings = text_encode(text_encoder, text_tokens)
if max_length is None:
max_length = tokenizer.model_max_length
text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)

return text_embeddings
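
Finally, a hedged end-to-end sketch of the helpers above against an SD1.5-style text encoder; the checkpoint id is only an example and requires a download, and only the default truncated path is shown since its output shape is unambiguous:

from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

embeds = encode_prompts(tokenizer, text_encoder, ["a photo of a cat"])
print(embeds.shape)  # torch.Size([1, 77, 768]); truncate=False switches to the chunked path above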