Fixes and longer prompts

This commit is contained in:
Jaret Burkett
2023-10-22 08:57:37 -06:00
parent 0e9fc42816
commit 9905a1e205
4 changed files with 177 additions and 54 deletions

View File

@@ -4,6 +4,7 @@ from diffusers import T2IAdapter
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
@@ -27,6 +28,10 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
def before_model_load(self):
pass
@@ -135,6 +140,40 @@ class SDTrainer(BaseSDTrainProcess):
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# dont use network on this
self.network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
self.network.multiplier = network_weight_list
return prior_pred
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
@@ -287,28 +326,18 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior:
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
with self.timer('prior predict'):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# dont use network on this
network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):

View File

@@ -55,6 +55,18 @@ transforms_dict = {
caption_ext_list = ['txt', 'json', 'caption']
def clean_caption(caption):
# remove any newlines
caption = caption.replace('\n', ', ')
# remove new lines for all operating systems
caption = caption.replace('\r', ', ')
caption_split = caption.split(',')
# remove empty strings
caption_split = [p.strip() for p in caption_split if p.strip()]
# join back together
caption = ', '.join(caption_split)
return caption
class CaptionMixin:
def get_caption_item(self: 'AiToolkitDataset', index):
@@ -91,15 +103,7 @@ class CaptionMixin:
if 'caption' in prompt:
prompt = prompt['caption']
# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
prompt = clean_caption(prompt)
else:
prompt = ''
# get default_prompt if it exists on the class instance
@@ -135,6 +139,10 @@ class BucketsMixin:
batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch)
def shuffle_buckets(self: 'AiToolkitDataset'):
for key, bucket in self.buckets.items():
random.shuffle(bucket.file_list_idx)
def setup_buckets(self: 'AiToolkitDataset', quiet=False):
if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
@@ -206,6 +214,7 @@ class BucketsMixin:
self.buckets[bucket_key].file_list_idx.append(idx)
# print the buckets
self.shuffle_buckets()
self.build_batch_indices()
if not quiet:
print(f'Bucket sizes for {self.dataset_path}:')

View File

@@ -327,7 +327,6 @@ class StableDiffusion:
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
# force turn that (ruin your images with obvious green and red dots) the #$@@ off!!!
pipeline.watermark = None
else:
pipeline = Pipe(
@@ -372,7 +371,8 @@ class StableDiffusion:
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(gen_config.width,
interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
])
validation_image = transform(validation_image)
@@ -395,14 +395,15 @@ class StableDiffusion:
unconditional_embeds,
)
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
if self.adapter is not None and isinstance(self.adapter,
IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -668,7 +669,15 @@ class StableDiffusion:
# return latents_steps
return latents
def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
def encode_prompt(
self,
prompt,
prompt2=None,
num_images_per_prompt=1,
force_all=False,
long_prompts=False,
max_length=None
) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768)
prompt = prompt
# if it is not a list, make it one
@@ -695,12 +704,14 @@ class StableDiffusion:
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=use_encoder_1,
use_text_encoder_2=use_encoder_2,
truncate=not long_prompts,
max_length=max_length,
)
)
else:
return PromptEmbeds(
train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt
self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
)
)

View File

@@ -447,29 +447,78 @@ if TYPE_CHECKING:
def text_tokenize(
tokenizer: 'CLIPTokenizer', # 普通ならひとつ、XLならふたつ
tokenizer: 'CLIPTokenizer',
prompts: list[str],
truncate: bool = True,
max_length: int = None,
max_length_multiplier: int = 4,
):
return tokenizer(
# allow fo up to 4x the max length for long prompts
if max_length is None:
if truncate:
max_length = tokenizer.model_max_length
else:
# allow up to 4x the max length for long prompts
max_length = tokenizer.model_max_length * max_length_multiplier
input_ids = tokenizer(
prompts,
padding="max_length",
max_length=tokenizer.model_max_length,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors="pt",
).input_ids
if truncate or max_length == tokenizer.model_max_length:
return input_ids
else:
# remove additional padding
num_chunks = input_ids.shape[1] // tokenizer.model_max_length
chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
# New list to store non-redundant chunks
non_redundant_chunks = []
for chunk in chunks:
if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element
non_redundant_chunks.append(chunk)
input_ids = torch.cat(non_redundant_chunks, dim=1)
return input_ids
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
tokens: torch.FloatTensor,
num_images_per_prompt: int = 1,
max_length: int = 77, # not sure what default to put here, always pass one?
truncate: bool = True,
):
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
if truncate:
# normal short prompt 77 tokens max
prompt_embeds = text_encoder(
tokens.to(text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
else:
# handle long prompts
prompt_embeds_list = []
tokens = tokens.to(text_encoder.device)
pooled_prompt_embeds = None
for i in range(0, tokens.shape[-1], max_length):
# todo run it through the in a single batch
section_tokens = tokens[:, i: i + max_length]
embeds = text_encoder(section_tokens, output_hidden_states=True)
pooled_prompt_embed = embeds[0]
if pooled_prompt_embeds is None:
# we only want the first ( I think??)
pooled_prompt_embeds = pooled_prompt_embed
prompt_embed = embeds.hidden_states[-2] # always penultimate layer
prompt_embeds_list.append(prompt_embed)
prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -485,7 +534,9 @@ def encode_prompts_xl(
prompts2: Union[list[str], None],
num_images_per_prompt: int = 1,
use_text_encoder_1: bool = True, # sdxl
use_text_encoder_2: bool = True # sdxl
use_text_encoder_2: bool = True, # sdxl
truncate: bool = True,
max_length=None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# text_encoder and text_encoder_2's penuultimate layer's output
text_embeds_list = []
@@ -502,9 +553,14 @@ def encode_prompts_xl(
if idx == 1 and not use_text_encoder_2:
prompt_list_to_use = ["" for _ in prompts]
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
# set the max length for the next one
if idx == 0:
max_length = text_tokens_input_ids.shape[-1]
text_embeds, pooled_text_embeds = text_encode_xl(
text_encoder, text_tokens_input_ids, num_images_per_prompt
text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
truncate=truncate
)
text_embeds_list.append(text_embeds)
@@ -517,18 +573,36 @@ def encode_prompts_xl(
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
def text_encode(text_encoder: 'CLIPTextModel', tokens):
return text_encoder(tokens.to(text_encoder.device))[0]
# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
if max_length is None and not truncate:
raise ValueError("max_length must be set if truncate is True")
tokens = tokens.to(text_encoder.device)
if truncate:
return text_encoder(tokens)[0]
else:
# handle long prompts
prompt_embeds_list = []
for i in range(0, tokens.shape[-1], max_length):
prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
prompt_embeds_list.append(prompt_embeds)
return torch.cat(prompt_embeds_list, dim=1)
def encode_prompts(
tokenizer: 'CLIPTokenizer',
text_encoder: 'CLIPTokenizer',
text_encoder: 'CLIPTextModel',
prompts: list[str],
truncate: bool = True,
max_length=None,
):
text_tokens = text_tokenize(tokenizer, prompts)
text_embeddings = text_encode(text_encoder, text_tokens)
if max_length is None:
max_length = tokenizer.model_max_length
text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
return text_embeddings