added prompt dropout to happen independently on each TE

Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions

View File

@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
             # offload it. Already cached
             self.sd.vae.to('cpu')
             flush()
-        self.sd.noise_scheduler.set_timesteps(1000)
-        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
         # you can expand these in a child class to make customization easier

@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('encode_prompt'):
             if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
-                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                               long_prompts=True).to(
+                    # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                    conditional_embeds = self.sd.encode_prompt(
+                        conditioned_prompts, prompt_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=True).to(
                         self.device_torch,
                         dtype=dtype)
             else:

@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
                         te.eval()
                 else:
                     self.sd.text_encoder.eval()
-                conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                           long_prompts=True).to(
+                # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                conditional_embeds = self.sd.encode_prompt(
+                    conditioned_prompts, prompt_2,
+                    dropout_prob=self.train_config.prompt_dropout_prob,
+                    long_prompts=True).to(
                    self.device_torch,
                    dtype=dtype)

View File

@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with open(path_to_save, 'w') as f:
                 json.dump(json_data, f, indent=4)

+        # save optimizer
+        if self.optimizer is not None:
+            try:
+                filename = f'optimizer.pt'
+                file_path = os.path.join(self.save_root, filename)
+                torch.save(self.optimizer.state_dict(), file_path)
+            except Exception as e:
+                print(e)
+                print("Could not save optimizer")
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)

@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 optimizer_params=self.train_config.optimizer_params)
             self.optimizer = optimizer

+            # check if it exists
+            optimizer_state_filename = f'optimizer.pt'
+            optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
+            if os.path.exists(optimizer_state_file_path):
+                # try to load
+                try:
+                    print(f"Loading optimizer state from {optimizer_state_file_path}")
+                    optimizer_state_dict = torch.load(optimizer_state_file_path)
+                    optimizer.load_state_dict(optimizer_state_dict)
+                except Exception as e:
+                    print(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                    print(e)
+
             lr_scheduler_params = self.train_config.lr_scheduler_params
             # make sure it has the bare minimum
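
Taken together, these two hunks add optimizer checkpointing: the state dict is written next to the other artifacts on save and reloaded at startup when present. A minimal self-contained sketch of the same pattern, with a placeholder model and save folder standing in for the trainer's attributes:

import os

import torch

model = torch.nn.Linear(8, 8)  # stand-in for the trained network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
save_root = '/tmp/training_run'  # hypothetical save folder
os.makedirs(save_root, exist_ok=True)
optimizer_state_file_path = os.path.join(save_root, 'optimizer.pt')

# on save: persist the optimizer state, but never let a failure abort training
try:
    torch.save(optimizer.state_dict(), optimizer_state_file_path)
except Exception as e:
    print(e)
    print("Could not save optimizer")

# on startup: resume momentum and other buffers if a saved state exists
if os.path.exists(optimizer_state_file_path):
    try:
        optimizer.load_state_dict(torch.load(optimizer_state_file_path))
    except Exception as e:
        print(f"Failed to load optimizer state from {optimizer_state_file_path}")
        print(e)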

View File

@@ -0,0 +1,67 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+args = parser.parse_args()
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+if args.sdxl:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.refiner:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.ssd:
+    adapter_id = "latent-consistency/lcm-lora-ssd-1b"
+else:
+    adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.load_lora_weights(adapter_id)
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+diffusers_sd.save(args.output_path, meta=meta)
+
+print(f"Saved to {args.output_path}")

View File

@@ -194,6 +194,9 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)

+        # dropout that happens before encoding. It functions independently per text encoder
+        self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
+
         # match the norm of the noise before computing loss. This will help the model maintain its
         # current understanding of the brightness of images.
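
Since TrainConfig pulls the value from kwargs with a default of 0.0, enabling the feature is a single setting. A sketch, assuming TrainConfig lives in toolkit.config_modules alongside ModelConfig (as the new script's imports suggest) and can be constructed from keyword arguments alone:

from toolkit.config_modules import TrainConfig

train_config = TrainConfig(prompt_dropout_prob=0.1)
# each text encoder now independently replaces roughly 10% of prompts with "" before encoding
print(train_config.prompt_dropout_prob)  # 0.1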

View File

@@ -12,7 +12,9 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    LCMScheduler
 )

 from k_diffusion.external import CompVisDenoiser

 # scheduler:

@@ -72,12 +74,15 @@ def get_sampler(
         scheduler_cls = KDPM2DiscreteScheduler
     elif sampler == "dpm_2_a":
         scheduler_cls = KDPM2AncestralDiscreteScheduler
+    elif sampler == "lcm":
+        scheduler_cls = LCMScheduler

     config = copy.deepcopy(sdxl_sampler_config)
     config.update(sched_init_args)
     scheduler = scheduler_cls.from_config(config)
     return scheduler
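
With the new branch, get_sampler("lcm") resolves to diffusers' LCMScheduler through the same from_config path as the other samplers. A standalone sketch of that construction, using the scheduler's own defaults in place of sdxl_sampler_config:

import copy

from diffusers import LCMScheduler

config = copy.deepcopy(dict(LCMScheduler().config))  # stand-in for sdxl_sampler_config
scheduler = LCMScheduler.from_config(config)
scheduler.set_timesteps(4)  # LCM models typically sample in ~4-8 steps
print(scheduler.timesteps)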

View File

@@ -344,6 +344,11 @@ class StableDiffusion:
         else:
             noise_scheduler = get_sampler(sampler)

+        try:
+            noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
+        except:
+            pass
+
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline

@@ -722,7 +727,8 @@ class StableDiffusion:
                 refiner_pred = self.refiner_unet(
                     input_chunks[1],
                     timestep_chunks[1],
-                    encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],  # just use the first second text encoder
+                    encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
+                    # just use the second text encoder
                     added_cond_kwargs={
                         "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                         # "time_ids": added_cond_kwargs_chunked['time_ids'][1],

@@ -740,7 +746,8 @@ class StableDiffusion:
                     # just use the second text encoder
                     added_cond_kwargs={
                         "text_embeds": text_embeddings.pooled_embeds,
-                        "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
+                        "time_ids": self.get_time_ids_from_latents(latent_model_input,
+                                                                   requires_aesthetic_score=True),
                     },
                     **kwargs,
                 ).sample

@@ -845,7 +852,8 @@ class StableDiffusion:
             num_images_per_prompt=1,
             force_all=False,
             long_prompts=False,
-            max_length=None
+            max_length=None,
+            dropout_prob=0.0,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt

@@ -875,12 +883,18 @@ class StableDiffusion:
                     use_text_encoder_2=use_encoder_2,
                     truncate=not long_prompts,
                     max_length=max_length,
+                    dropout_prob=dropout_prob,
                 )
             )
         else:
             return PromptEmbeds(
                 train_tools.encode_prompts(
-                    self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob
                 )
             )

@@ -1011,8 +1025,9 @@ class StableDiffusion:
                 state_dict[new_key] = v
         return state_dict

-    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
-        str, Parameter]:
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
+            OrderedDict[
+                str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
             for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):

@@ -1198,7 +1213,8 @@ class StableDiffusion:
             print(f"Found {len(params)} trainable parameters in text encoder")
         if refiner:
-            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
+            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
+                                                 state_dict_keys=True)
             refiner_lr = refiner_lr if refiner_lr is not None else default_lr
             params = []
             for key, diffusers_key in ldm_diffusers_keymap.items():

View File

@@ -537,6 +537,7 @@ def encode_prompts_xl(
     use_text_encoder_2: bool = True,  # sdxl
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penultimate layer's output
     text_embeds_list = []

@@ -553,6 +554,12 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]

+        if dropout_prob > 0.0:
+            # randomly drop out prompts
+            prompt_list_to_use = [
+                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+            ]
+
         text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
         # set the max length for the next one
         if idx == 0:

@@ -598,9 +605,17 @@ def encode_prompts(
     prompts: list[str],
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ):
     if max_length is None:
         max_length = tokenizer.model_max_length

+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
     text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
     text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
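
Because the dropout draw sits inside the per-encoder loop of encode_prompts_xl (and runs again in encode_prompts), each text encoder rolls its own dice for every prompt, which is exactly the "independently on each TE" behavior from the commit message. A minimal sketch isolating that behavior:

import torch

def drop_prompts(prompts, dropout_prob):
    # same comprehension as above: each prompt is independently replaced with ""
    return [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]

torch.manual_seed(0)
prompts = ["a photo of a cat"] * 4

te1_prompts = drop_prompts(prompts, 0.5)  # draw for text encoder 1
te2_prompts = drop_prompts(prompts, 0.5)  # independent draw for text encoder 2

# the two lists generally differ: one encoder can keep a prompt the other dropped
print(te1_prompts)
print(te2_prompts)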