Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 03:01:28 +00:00
added prompt dropout to happen independently on each TE
@@ -62,8 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
             # offload it. Already cached
             self.sd.vae.to('cpu')
             flush()
-
-        self.sd.noise_scheduler.set_timesteps(1000)
         add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
 
         # you can expand these in a child class to make customization easier
@@ -478,9 +476,10 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('encode_prompt'):
             if grad_on_text_encoder:
                 with torch.set_grad_enabled(True):
-                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                               long_prompts=True).to(
-                    # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                    conditional_embeds = self.sd.encode_prompt(
+                        conditioned_prompts, prompt_2,
+                        dropout_prob=self.train_config.prompt_dropout_prob,
+                        long_prompts=True).to(
                         self.device_torch,
                         dtype=dtype)
             else:
@@ -491,9 +490,10 @@ class SDTrainer(BaseSDTrainProcess):
                         te.eval()
                 else:
                     self.sd.text_encoder.eval()
-                conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
-                                                           long_prompts=True).to(
-                # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                conditional_embeds = self.sd.encode_prompt(
+                    conditioned_prompts, prompt_2,
+                    dropout_prob=self.train_config.prompt_dropout_prob,
+                    long_prompts=True).to(
                     self.device_torch,
                     dtype=dtype)
 
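Taken together, the two call sites above thread dropout_prob from the training config into encode_prompt, so prompts are dropped before they are encoded and a dropped prompt contributes an empty-prompt (unconditional) embedding. A self-contained sketch of that idea; apply_prompt_dropout mirrors the comprehension the commit adds, while encode is a toy stand-in for sd.encode_prompt, not the toolkit's API:

import torch

def apply_prompt_dropout(prompts: list[str], dropout_prob: float) -> list[str]:
    # same comprehension the commit adds: each prompt is independently
    # replaced with "" with probability dropout_prob
    return [p if torch.rand(1).item() > dropout_prob else "" for p in prompts]

def encode(prompts: list[str]) -> torch.Tensor:
    # toy stand-in encoder: embeds each string by its length, so dropped
    # entries ("") show up as the empty-prompt embedding
    return torch.tensor([[float(len(p))] for p in prompts])

torch.manual_seed(0)
batch = ["a cat", "a dog", "a house", "a tree"]
embeds = encode(apply_prompt_dropout(batch, dropout_prob=0.25))
print(embeds.squeeze(1).tolist())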
@@ -419,6 +419,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with open(path_to_save, 'w') as f:
             json.dump(json_data, f, indent=4)
 
+        # save optimizer
+        if self.optimizer is not None:
+            try:
+                filename = f'optimizer.pt'
+                file_path = os.path.join(self.save_root, filename)
+                torch.save(self.optimizer.state_dict(), file_path)
+            except Exception as e:
+                print(e)
+                print("Could not save optimizer")
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -1121,6 +1131,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
 
+        # check if it exists
+        optimizer_state_filename = f'optimizer.pt'
+        optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
+        if os.path.exists(optimizer_state_file_path):
+            # try to load
+            try:
+                print(f"Loading optimizer state from {optimizer_state_file_path}")
+                optimizer_state_dict = torch.load(optimizer_state_file_path)
+                optimizer.load_state_dict(optimizer_state_dict)
+            except Exception as e:
+                print(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print(e)
+
         lr_scheduler_params = self.train_config.lr_scheduler_params
 
         # make sure it has the bare minimum
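Together, these two hunks give optimizer-state checkpointing: save() writes optimizer.pt next to the model, and the optimizer setup path reloads it on the next run. A minimal, self-contained round trip of the same pattern, where torch.nn.Linear and AdamW stand in for the real model and optimizer:

import os
import tempfile

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

save_root = tempfile.mkdtemp()
file_path = os.path.join(save_root, 'optimizer.pt')

# save at checkpoint time, as in BaseSDTrainProcess.save()
torch.save(optimizer.state_dict(), file_path)

# reload on the next run, as in the optimizer setup path above
if os.path.exists(file_path):
    optimizer.load_state_dict(torch.load(file_path))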
scripts/make_lcm_sdxl_model.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+if args.sdxl:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.refiner:
+    adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.ssd:
+    adapter_id = "latent-consistency/lcm-lora-ssd-1b"
+else:
+    adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.load_lora_weights(adapter_id)
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
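The script wraps the standard diffusers LCM-LoRA fusing flow in the toolkit's StableDiffusion wrapper; presumably it is invoked as python scripts/make_lcm_sdxl_model.py /path/to/base.safetensors /path/to/out.safetensors --sdxl. For reference, a sketch of the underlying diffusers pattern, assuming the public SDXL base checkpoint as example input; note this saves diffusers-format output, whereas the script writes the toolkit's own format:

import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float32
)
# swap in the LCM scheduler and bake the distilled LoRA into the base weights
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.fuse_lora()
pipe.save_pretrained("sdxl-lcm-fused")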
@@ -194,6 +194,9 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # dropout that happens before encoding. It functions independently per text encoder
+        self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
+
         # match the norm of the noise before computing loss. This will help the model maintain its
         # current understanding of the brightness of images.
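prompt_dropout_prob defaults to 0.0, so existing configs are unaffected. A small sketch of how the new knob surfaces through the kwargs.get pattern; the dict below is an illustrative config excerpt, not a complete training config:

# hypothetical excerpt of the kwargs consumed by TrainConfig.__init__
train_kwargs = {
    'noise_multiplier': 1.0,
    'img_multiplier': 1.0,
    'prompt_dropout_prob': 0.1,  # drop each prompt to "" with 10% probability, per text encoder
}

# minimal mimic of the kwargs.get pattern used by TrainConfig
prompt_dropout_prob = train_kwargs.get('prompt_dropout_prob', 0.0)
print(prompt_dropout_prob)  # 0.1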
@@ -12,7 +12,9 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    LCMScheduler
 )
+
 from k_diffusion.external import CompVisDenoiser
 
 # scheduler:
@@ -72,12 +74,15 @@ def get_sampler(
         scheduler_cls = KDPM2DiscreteScheduler
     elif sampler == "dpm_2_a":
         scheduler_cls = KDPM2AncestralDiscreteScheduler
+    elif sampler == "lcm":
+        scheduler_cls = LCMScheduler
+
     config = copy.deepcopy(sdxl_sampler_config)
     config.update(sched_init_args)
 
     scheduler = scheduler_cls.from_config(config)
 
 
     return scheduler
 
 
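With this branch, get_sampler("lcm") builds an LCMScheduler from the shared SDXL sampler config. A self-contained mimic of that selection logic; the config dict and the EulerDiscreteScheduler default branch here are stand-ins for the toolkit's own sdxl_sampler_config and sampler table:

from diffusers import EulerDiscreteScheduler, LCMScheduler

# stand-in for the toolkit's sdxl_sampler_config
sdxl_sampler_config = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
}

def get_sampler(sampler: str):
    scheduler_cls = EulerDiscreteScheduler  # stand-in default branch
    if sampler == "lcm":
        scheduler_cls = LCMScheduler
    return scheduler_cls.from_config(dict(sdxl_sampler_config))

scheduler = get_sampler("lcm")
print(type(scheduler).__name__)  # LCMScheduler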
@@ -344,6 +344,11 @@ class StableDiffusion:
         else:
             noise_scheduler = get_sampler(sampler)
 
+        try:
+            noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
+        except:
+            pass
+
         if sampler.startswith("sample_") and self.is_xl:
             # using kdiffusion
             Pipe = StableDiffusionKDiffusionXLPipeline
@@ -722,7 +727,8 @@ class StableDiffusion:
             refiner_pred = self.refiner_unet(
                 input_chunks[1],
                 timestep_chunks[1],
-                encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],  # just use the first second text encoder
+                encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
+                # just use the second text encoder
                 added_cond_kwargs={
                     "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                     # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
@@ -740,7 +746,8 @@ class StableDiffusion:
                 # just use the second text encoder
                 added_cond_kwargs={
                     "text_embeds": text_embeddings.pooled_embeds,
-                    "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
+                    "time_ids": self.get_time_ids_from_latents(latent_model_input,
+                                                               requires_aesthetic_score=True),
                 },
                 **kwargs,
             ).sample
@@ -845,7 +852,8 @@ class StableDiffusion:
             num_images_per_prompt=1,
             force_all=False,
             long_prompts=False,
-            max_length=None
+            max_length=None,
+            dropout_prob=0.0,
     ) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
@@ -875,12 +883,18 @@ class StableDiffusion:
                     use_text_encoder_2=use_encoder_2,
                     truncate=not long_prompts,
                     max_length=max_length,
+                    dropout_prob=dropout_prob,
                 )
             )
         else:
             return PromptEmbeds(
                 train_tools.encode_prompts(
-                    self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob
                 )
             )
 
@@ -1011,8 +1025,9 @@ class StableDiffusion:
                 state_dict[new_key] = v
         return state_dict
 
-    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
-        str, Parameter]:
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
+            OrderedDict[
+                str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
             for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
@@ -1198,7 +1213,8 @@ class StableDiffusion:
         print(f"Found {len(params)} trainable parameters in text encoder")
 
         if refiner:
-            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
+            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
+                                                 state_dict_keys=True)
             refiner_lr = refiner_lr if refiner_lr is not None else default_lr
             params = []
             for key, diffusers_key in ldm_diffusers_keymap.items():
@@ -537,6 +537,7 @@ def encode_prompts_xl(
         use_text_encoder_2: bool = True,  # sdxl
         truncate: bool = True,
         max_length=None,
+        dropout_prob=0.0,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penultimate layer's output
     text_embeds_list = []
@@ -553,6 +554,12 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]
 
+        if dropout_prob > 0.0:
+            # randomly drop out prompts
+            prompt_list_to_use = [
+                prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+            ]
+
         text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
         # set the max length for the next one
         if idx == 0:
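Because this block sits inside the per-encoder loop of encode_prompts_xl, each text encoder rolls its own dice: the same prompt can survive for one encoder while being dropped for the other, which is what the commit message means by happening independently on each TE. A small self-contained demonstration:

import torch

torch.manual_seed(42)
prompts = ["a photo of a cat"] * 8
p = 0.5

# two independent draws, as in the two loop iterations of encode_prompts_xl
te1 = [x if torch.rand(1).item() > p else "" for x in prompts]
te2 = [x if torch.rand(1).item() > p else "" for x in prompts]
# each pair shows an independent keep/drop decision per encoder
print([(a != "", b != "") for a, b in zip(te1, te2)])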
@@ -598,9 +605,17 @@ def encode_prompts(
         prompts: list[str],
         truncate: bool = True,
         max_length=None,
+        dropout_prob=0.0,
 ):
     if max_length is None:
         max_length = tokenizer.model_max_length
 
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
     text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
     text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
 
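The keep condition in both helpers is torch.rand(1).item() > dropout_prob, so a prompt is dropped with probability exactly dropout_prob. A quick empirical check of that rate:

import torch

torch.manual_seed(0)
dropout_prob = 0.1
draws = torch.rand(100_000)
# fraction of draws that would trigger a drop
print((draws <= dropout_prob).float().mean().item())  # ~0.1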