Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
added prompt dropout to happen independently on each TE
@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
         # offload it. Already cached
         self.sd.vae.to('cpu')
         flush()

-        self.sd.noise_scheduler.set_timesteps(1000)
-        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)

         # you can expand these in a child class to make customization easier
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('encode_prompt'):
             if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
-                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                               long_prompts=True).to(
-                    # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                    conditional_embeds = self.sd.encode_prompt(
+                        conditioned_prompts, prompt_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=True).to(
                         self.device_torch,
                         dtype=dtype)
             else:
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
                     te.eval()
             else:
                 self.sd.text_encoder.eval()
-            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                       long_prompts=True).to(
-            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+            conditional_embeds = self.sd.encode_prompt(
+                conditioned_prompts, prompt_2,
+                dropout_prob=self.train_config.prompt_dropout_prob,
+                long_prompts=True).to(
                 self.device_torch,
                 dtype=dtype)

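Note: both branches (grad and no-grad) now pass the same prompt_dropout_prob, so prompt dropout applies whether or not the text encoders are being trained. The dropout itself happens inside encode_prompt, once per text encoder (see the train_tools hunks below). A minimal sketch of why one draw per encoder makes the drops independent, using a hypothetical maybe_drop helper that mirrors the comprehension added in this commit:

import torch

def maybe_drop(prompts: list[str], dropout_prob: float) -> list[str]:
    # each prompt is independently replaced with "" with probability dropout_prob
    return [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]

prompts = ["a photo of a cat"] * 4
te1_prompts = maybe_drop(prompts, 0.5)  # draw for text encoder 1
te2_prompts = maybe_drop(prompts, 0.5)  # fresh, independent draw for text encoder 2
print(te1_prompts)
print(te2_prompts)  # usually differs from te1_prompts

With probability p per encoder, both encoders see the empty prompt only p^2 of the time.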
@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with open(path_to_save, 'w') as f:
             json.dump(json_data, f, indent=4)

+        # save optimizer
+        if self.optimizer is not None:
+            try:
+                filename = f'optimizer.pt'
+                file_path = os.path.join(self.save_root, filename)
+                torch.save(self.optimizer.state_dict(), file_path)
+            except Exception as e:
+                print(e)
+                print("Could not save optimizer")
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
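The try/except ensures a failed optimizer dump never aborts the checkpoint save. One possible hardening, not part of this commit: write to a temporary file and swap it into place, so a crash mid-save cannot leave a truncated optimizer.pt. A sketch under that assumption:

import os
import torch

def save_optimizer_state(optimizer: torch.optim.Optimizer, save_root: str) -> str:
    # write to a temp path first, then atomically replace the old file
    final_path = os.path.join(save_root, 'optimizer.pt')
    tmp_path = final_path + '.tmp'
    torch.save(optimizer.state_dict(), tmp_path)
    os.replace(tmp_path, final_path)
    return final_path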
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer

+        # check if it exists
+        optimizer_state_filename = f'optimizer.pt'
+        optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
+        if os.path.exists(optimizer_state_file_path):
+            # try to load
+            try:
+                print(f"Loading optimizer state from {optimizer_state_file_path}")
+                optimizer_state_dict = torch.load(optimizer_state_file_path)
+                optimizer.load_state_dict(optimizer_state_dict)
+            except Exception as e:
+                print(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print(e)
+
         lr_scheduler_params = self.train_config.lr_scheduler_params

         # make sure it had bare minimum
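Together with the save hunk above, this makes optimizer state resumable: Adam-style moments (exp_avg, exp_avg_sq) survive a restart instead of re-warming from zero. A self-contained sketch of the round trip in plain PyTorch, not this repo's classes:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# ... train, then checkpoint ...
torch.save(optimizer.state_dict(), 'optimizer.pt')

# on resume: rebuild the optimizer over the same parameters in the same order,
# then restore; load_state_dict maps saved state to parameters by position
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer.load_state_dict(torch.load('optimizer.pt'))

That position dependence is why the load is wrapped in try/except: if the trainable parameter set changed between runs, load_state_dict raises and training falls back to a fresh optimizer.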
scripts/make_lcm_sdxl_model.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+if args.sdxl:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+    if args.refiner:
+        adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.ssd:
+    adapter_id = "latent-consistency/lcm-lora-ssd-1b"
+else:
+    adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.load_lora_weights(adapter_id)
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
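Assuming the repo's usual CLI conventions, the script would be run as python scripts/make_lcm_sdxl_model.py /path/to/model.safetensors /path/to/output.safetensors --sdxl: it loads the base model, fuses the matching latent-consistency LCM LoRA from the Hugging Face Hub into the weights, and saves the result as a standalone checkpoint.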
@@ -194,6 +194,9 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)

+        # dropout that happens before encoding. It functions independently per text encoder
+        self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
+
         # match the norm of the noise before computing loss. This will help the model maintain its
         # current understandin of the brightness of images.

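A hedged usage sketch (assuming TrainConfig accepts arbitrary kwargs, as the kwargs.get pattern above suggests); the 0.1 mirrors the roughly 10% unconditional rate commonly used when training for classifier-free guidance:

from toolkit.config_modules import TrainConfig

train_config = TrainConfig(prompt_dropout_prob=0.1)
print(train_config.prompt_dropout_prob)  # 0.1

The default of 0.0 disables the feature entirely, since both encode paths guard on dropout_prob > 0.0.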
@@ -12,7 +12,9 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    LCMScheduler
 )

+from k_diffusion.external import CompVisDenoiser

 # scheduler:
@@ -72,12 +74,15 @@ def get_sampler(
         scheduler_cls = KDPM2DiscreteScheduler
     elif sampler == "dpm_2_a":
         scheduler_cls = KDPM2AncestralDiscreteScheduler
+    elif sampler == "lcm":
+        scheduler_cls = LCMScheduler

     config = copy.deepcopy(sdxl_sampler_config)
     config.update(sched_init_args)

     scheduler = scheduler_cls.from_config(config)


     return scheduler
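With LCMScheduler wired in, get_sampler("lcm") returns an LCM scheduler built from the shared SDXL sampler config. For reference, the equivalent stock diffusers pattern (a sketch, not this repo's code) is:

from diffusers import LCMScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# LCM swaps in a new scheduler and needs only a handful of steps (typically 4 to 8)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]

The low guidance scale is the usual LCM-LoRA setting, consistent with the fused-LoRA export script added above.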
@@ -344,6 +344,11 @@ class StableDiffusion:
         else:
             noise_scheduler = get_sampler(sampler)

+        try:
+            noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
+        except:
+            pass
+
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline
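The bare try/except is presumably there because most diffusers schedulers do not implement .to(); only wrapper objects carrying device state do. A narrower duck-typed guard (an alternative sketch, not what the commit does) avoids swallowing unrelated errors:

def move_scheduler(scheduler, device, dtype):
    # move only if the scheduler actually supports device placement
    if hasattr(scheduler, 'to') and callable(getattr(scheduler, 'to')):
        return scheduler.to(device, dtype)
    return scheduler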
@@ -722,7 +727,8 @@ class StableDiffusion:
                 refiner_pred = self.refiner_unet(
                     input_chunks[1],
                     timestep_chunks[1],
-                    encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],  # just use the first second text encoder
+                    encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
+                    # just use the first second text encoder
                     added_cond_kwargs={
                         "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                         # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
@@ -740,7 +746,8 @@ class StableDiffusion:
                     # just use the first second text encoder
                     added_cond_kwargs={
                         "text_embeds": text_embeddings.pooled_embeds,
-                        "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
+                        "time_ids": self.get_time_ids_from_latents(latent_model_input,
+                                                                   requires_aesthetic_score=True),
                     },
                     **kwargs,
                 ).sample
@@ -845,7 +852,8 @@ class StableDiffusion:
             num_images_per_prompt=1,
             force_all=False,
             long_prompts=False,
-            max_length=None
+            max_length=None,
+            dropout_prob=0.0,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
@@ -875,12 +883,18 @@ class StableDiffusion:
                     use_text_encoder_2=use_encoder_2,
                     truncate=not long_prompts,
                     max_length=max_length,
+                    dropout_prob=dropout_prob,
                 )
             )
         else:
             return PromptEmbeds(
                 train_tools.encode_prompts(
-                    self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob
                 )
             )

@@ -1011,8 +1025,9 @@ class StableDiffusion:
             state_dict[new_key] = v
         return state_dict

-    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
-        str, Parameter]:
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
+            OrderedDict[
+                str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
             for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -1198,7 +1213,8 @@ class StableDiffusion:
         print(f"Found {len(params)} trainable parameter in text encoder")

         if refiner:
-            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
+            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
+                                                 state_dict_keys=True)
             refiner_lr = refiner_lr if refiner_lr is not None else default_lr
             params = []
             for key, diffusers_key in ldm_diffusers_keymap.items():
@@ -537,6 +537,7 @@ def encode_prompts_xl(
     use_text_encoder_2: bool = True,  # sdxl
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penuultimate layer's output
     text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]

+        if dropout_prob > 0.0:
+            # randomly drop out prompts
+            prompt_list_to_use = [
+                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+            ]
+
         text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
         # set the max length for the next one
         if idx == 0:
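Because this comprehension runs inside the per-encoder loop (idx 0 for text_encoder, idx 1 for text_encoder_2), each encoder draws its own torch.rand per prompt, which is exactly the independent-per-TE behavior named in the commit title. A standalone sanity check of the resulting rates:

import torch

torch.manual_seed(0)
p, n = 0.1, 100_000
te1_dropped = torch.rand(n) <= p  # drop mask for text encoder 1
te2_dropped = torch.rand(n) <= p  # independent drop mask for text encoder 2
print((te1_dropped & te2_dropped).float().mean())  # ~p*p = 0.01, both see ""
print((te1_dropped ^ te2_dropped).float().mean())  # ~2p(1-p) = 0.18, exactly one dropped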
@@ -598,9 +605,17 @@ def encode_prompts(
     prompts: list[str],
     truncate: bool = True,
     max_length=None,
+    dropout_prob=0.0,
 ):
     if max_length is None:
         max_length = tokenizer.model_max_length
+
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
     text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
     text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)

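For single-text-encoder models this path reduces to the standard classifier-free guidance recipe of training on the empty prompt some fraction of the time; the SDXL path above differs only in drawing that fraction separately for each encoder.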