added prompt dropout to happen independently on each TE

This commit is contained in:
Jaret Burkett
2023-11-14 05:26:51 -07:00
parent 7782caa468
commit 4f9cdd916a
7 changed files with 144 additions and 15 deletions

View File

@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
# offload it. Already cached # offload it. Already cached
self.sd.vae.to('cpu') self.sd.vae.to('cpu')
flush() flush()
self.sd.noise_scheduler.set_timesteps(1000)
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
# you can expand these in a child class to make customization easier # you can expand these in a child class to make customization easier
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'): with self.timer('encode_prompt'):
if grad_on_text_encoder: if grad_on_text_encoder:
with torch.set_grad_enabled(True): with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, conditional_embeds = self.sd.encode_prompt(
long_prompts=True).to( conditioned_prompts, prompt_2,
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
self.device_torch, self.device_torch,
dtype=dtype) dtype=dtype)
else: else:
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
te.eval() te.eval()
else: else:
self.sd.text_encoder.eval() self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, conditional_embeds = self.sd.encode_prompt(
long_prompts=True).to( conditioned_prompts, prompt_2,
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=True).to(
self.device_torch, self.device_torch,
dtype=dtype) dtype=dtype)

View File

@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
with open(path_to_save, 'w') as f: with open(path_to_save, 'w') as f:
json.dump(json_data, f, indent=4) json.dump(json_data, f, indent=4)
# save optimizer
if self.optimizer is not None:
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
torch.save(self.optimizer.state_dict(), file_path)
except Exception as e:
print(e)
print("Could not save optimizer")
self.print(f"Saved to {file_path}") self.print(f"Saved to {file_path}")
self.clean_up_saves() self.clean_up_saves()
self.post_save_hook(file_path) self.post_save_hook(file_path)
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
optimizer_params=self.train_config.optimizer_params) optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer self.optimizer = optimizer
# check if it exists
optimizer_state_filename = f'optimizer.pt'
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
if os.path.exists(optimizer_state_file_path):
# try to load
try:
print(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path)
optimizer.load_state_dict(optimizer_state_dict)
except Exception as e:
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)
lr_scheduler_params = self.train_config.lr_scheduler_params lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it had bare minimum # make sure it had bare minimum

View File

@@ -0,0 +1,67 @@
import argparse
from collections import OrderedDict
import torch
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
# Hugging Face repo ids of the LCM LoRA adapters that can be fused into a model.
LCM_LORA_SDXL = "latent-consistency/lcm-lora-sdxl"
LCM_LORA_SSD_1B = "latent-consistency/lcm-lora-ssd-1b"
LCM_LORA_SD15 = "latent-consistency/lcm-lora-sdv1-5"


def select_adapter_id(sdxl: bool = False, refiner: bool = False, ssd: bool = False) -> str:
    """Return the LCM LoRA repo id matching the model-type flags.

    Precedence is refiner, then ssd, then sdxl, then the SD 1.5 fallback.

    The previous implementation used two separate ``if`` statements, so a
    plain ``--sdxl`` run fell through to the trailing ``else`` and picked the
    SD 1.5 adapter. Every other flag combination resolves as it did before.
    """
    if refiner:
        return LCM_LORA_SDXL
    if ssd:
        return LCM_LORA_SSD_1B
    if sdxl:
        return LCM_LORA_SDXL
    return LCM_LORA_SD15


def parse_args() -> argparse.Namespace:
    """Parse the command line: input/output paths plus model-type flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_path',
        type=str,
        help='Path to original sdxl model'
    )
    parser.add_argument(
        'output_path',
        type=str,
        help='output path'
    )
    parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
    parser.add_argument('--refiner', action='store_true', help='is refiner model')
    parser.add_argument('--ssd', action='store_true', help='is ssd model')
    parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
    return parser.parse_args()


def main() -> None:
    """Load a model, fuse the matching LCM LoRA into it and save the result."""
    args = parse_args()

    # The merge is done on CPU in float32.
    device = torch.device('cpu')
    dtype = torch.float32

    print(f"Loading model from {args.input_path}")

    adapter_id = select_adapter_id(
        sdxl=args.sdxl,
        refiner=args.refiner,
        ssd=args.ssd,
    )

    diffusers_model_config = ModelConfig(
        name_or_path=args.input_path,
        is_xl=args.sdxl,
        is_v2=args.sd2,
        is_ssd=args.ssd,
        dtype=dtype,
    )
    diffusers_sd = StableDiffusion(
        model_config=diffusers_model_config,
        device=device,
        dtype=dtype,
    )
    diffusers_sd.load_model()

    print(f"Loaded model from {args.input_path}")

    # Load the adapter into the pipeline, then fuse it before saving.
    diffusers_sd.pipeline.load_lora_weights(adapter_id)
    diffusers_sd.pipeline.fuse_lora()

    meta = OrderedDict()

    diffusers_sd.save(args.output_path, meta=meta)

    print(f"Saved to {args.output_path}")


if __name__ == "__main__":
    main()

View File

@@ -194,6 +194,9 @@ class TrainConfig:
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# dropout that happens before encoding. It functions independently per text encoder
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
# match the norm of the noise before computing loss. This will help the model maintain its # match the norm of the noise before computing loss. This will help the model maintain its
# current understanding of the brightness of images. # current understanding of the brightness of images.

View File

@@ -12,7 +12,9 @@ from diffusers import (
HeunDiscreteScheduler, HeunDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
LCMScheduler
) )
from k_diffusion.external import CompVisDenoiser from k_diffusion.external import CompVisDenoiser
# scheduler: # scheduler:
@@ -72,12 +74,15 @@ def get_sampler(
scheduler_cls = KDPM2DiscreteScheduler scheduler_cls = KDPM2DiscreteScheduler
elif sampler == "dpm_2_a": elif sampler == "dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_cls = KDPM2AncestralDiscreteScheduler
elif sampler == "lcm":
scheduler_cls = LCMScheduler
config = copy.deepcopy(sdxl_sampler_config) config = copy.deepcopy(sdxl_sampler_config)
config.update(sched_init_args) config.update(sched_init_args)
scheduler = scheduler_cls.from_config(config) scheduler = scheduler_cls.from_config(config)
return scheduler return scheduler

View File

@@ -344,6 +344,11 @@ class StableDiffusion:
else: else:
noise_scheduler = get_sampler(sampler) noise_scheduler = get_sampler(sampler)
try:
noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
except:
pass
if sampler.startswith("sample_") and self.is_xl: if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion # using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline Pipe = StableDiffusionKDiffusionXLPipeline
@@ -722,7 +727,8 @@ class StableDiffusion:
refiner_pred = self.refiner_unet( refiner_pred = self.refiner_unet(
input_chunks[1], input_chunks[1],
timestep_chunks[1], timestep_chunks[1],
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
# just use the first second text encoder
added_cond_kwargs={ added_cond_kwargs={
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1], "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
# "time_ids": added_cond_kwargs_chunked['time_ids'][1], # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
@@ -740,7 +746,8 @@ class StableDiffusion:
# just use the first second text encoder # just use the first second text encoder
added_cond_kwargs={ added_cond_kwargs={
"text_embeds": text_embeddings.pooled_embeds, "text_embeds": text_embeddings.pooled_embeds,
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True), "time_ids": self.get_time_ids_from_latents(latent_model_input,
requires_aesthetic_score=True),
}, },
**kwargs, **kwargs,
).sample ).sample
@@ -845,7 +852,8 @@ class StableDiffusion:
num_images_per_prompt=1, num_images_per_prompt=1,
force_all=False, force_all=False,
long_prompts=False, long_prompts=False,
max_length=None max_length=None,
dropout_prob=0.0,
) -> PromptEmbeds: ) -> PromptEmbeds:
# sd1.5 embeddings are (bs, 77, 768) # sd1.5 embeddings are (bs, 77, 768)
prompt = prompt prompt = prompt
@@ -875,12 +883,18 @@ class StableDiffusion:
use_text_encoder_2=use_encoder_2, use_text_encoder_2=use_encoder_2,
truncate=not long_prompts, truncate=not long_prompts,
max_length=max_length, max_length=max_length,
dropout_prob=dropout_prob,
) )
) )
else: else:
return PromptEmbeds( return PromptEmbeds(
train_tools.encode_prompts( train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length self.tokenizer,
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=max_length,
dropout_prob=dropout_prob
) )
) )
@@ -1011,8 +1025,9 @@ class StableDiffusion:
state_dict[new_key] = v state_dict[new_key] = v
return state_dict return state_dict
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[ def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
str, Parameter]: OrderedDict[
str, Parameter]:
named_params: OrderedDict[str, Parameter] = OrderedDict() named_params: OrderedDict[str, Parameter] = OrderedDict()
if vae: if vae:
for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -1198,7 +1213,8 @@ class StableDiffusion:
print(f"Found {len(params)} trainable parameter in text encoder") print(f"Found {len(params)} trainable parameter in text encoder")
if refiner: if refiner:
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True) named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
state_dict_keys=True)
refiner_lr = refiner_lr if refiner_lr is not None else default_lr refiner_lr = refiner_lr if refiner_lr is not None else default_lr
params = [] params = []
for key, diffusers_key in ldm_diffusers_keymap.items(): for key, diffusers_key in ldm_diffusers_keymap.items():

View File

@@ -537,6 +537,7 @@ def encode_prompts_xl(
use_text_encoder_2: bool = True, # sdxl use_text_encoder_2: bool = True, # sdxl
truncate: bool = True, truncate: bool = True,
max_length=None, max_length=None,
dropout_prob=0.0,
) -> tuple[torch.FloatTensor, torch.FloatTensor]: ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
# text_encoder and text_encoder_2's penultimate layer's output # text_encoder and text_encoder_2's penultimate layer's output
text_embeds_list = [] text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
if idx == 1 and not use_text_encoder_2: if idx == 1 and not use_text_encoder_2:
prompt_list_to_use = ["" for _ in prompts] prompt_list_to_use = ["" for _ in prompts]
if dropout_prob > 0.0:
# randomly drop out prompts
prompt_list_to_use = [
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
]
text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length) text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
# set the max length for the next one # set the max length for the next one
if idx == 0: if idx == 0:
@@ -598,9 +605,17 @@ def encode_prompts(
prompts: list[str], prompts: list[str],
truncate: bool = True, truncate: bool = True,
max_length=None, max_length=None,
dropout_prob=0.0,
): ):
if max_length is None: if max_length is None:
max_length = tokenizer.model_max_length max_length = tokenizer.model_max_length
if dropout_prob > 0.0:
# randomly drop out prompts
prompts = [
prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
]
text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length) text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length) text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)