Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00

Commit: Added refiner fine tuning. Works, but needs some polish.
@@ -382,8 +382,8 @@ class SDTrainer(BaseSDTrainProcess):
                 adapter_strength_max = 1.0
             else:
                 # training with assistance, we want it low
-                adapter_strength_min = 0.5
-                adapter_strength_max = 0.8
+                adapter_strength_min = 0.4
+                adapter_strength_max = 0.7
                 # adapter_strength_min = 0.9
                 # adapter_strength_max = 1.1

@@ -431,6 +431,8 @@ class SDTrainer(BaseSDTrainProcess):

             # make the batch splits
             if self.train_config.single_item_batching:
+                if self.model_config.refiner_name_or_path is not None:
+                    raise ValueError("Single item batching is not supported when training the refiner")
                 batch_size = noisy_latents.shape[0]
                 # chunk/split everything
                 noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
@@ -452,7 +454,6 @@ class SDTrainer(BaseSDTrainProcess):
                 prompt_2_list = [[prompt] for prompt in prompts_2]

             else:
-                # but it all in an array
                 noisy_latents_list = [noisy_latents]
                 noise_list = [noise]
                 timesteps_list = [timesteps]
@@ -603,8 +604,13 @@ class SDTrainer(BaseSDTrainProcess):
                     # apply gradients
                     self.optimizer.step()
                     self.optimizer.zero_grad(set_to_none=True)
-                with self.timer('scheduler_step'):
-                    self.lr_scheduler.step()
+                else:
+                    # gradient accumulation. Just a place for breakpoint
+                    pass
+
+                # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
+                with self.timer('scheduler_step'):
+                    self.lr_scheduler.step()

                 if self.embedding is not None:
                     with self.timer('restore_embeddings'):
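Note: the hunk above moves the lr_scheduler.step() call out of the optimizer-step branch, and the new TODO asks whether it should only run on gradient steps. A minimal sketch of what that change would imply, with illustrative values that are not from this commit:

    # If the scheduler only stepped when the optimizer steps, its total step count would
    # shrink by the accumulation factor, so the schedule length would need recalculating.
    total_loop_steps = 10000
    gradient_accumulation_steps = 4
    scheduler_steps_current = total_loop_steps                                    # stepped every iteration (behaviour above)
    scheduler_steps_grad_only = total_loop_steps // gradient_accumulation_steps   # 2500, the "recalculate last step" case
    print(scheduler_steps_current, scheduler_steps_grad_only)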
@@ -152,6 +152,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_lora=self.network_config is not None,
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
+            train_refiner=self.train_config.train_refiner,
         )

         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
@@ -382,16 +383,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
            file_path = file_path.replace('.safetensors', '')
            # convert it back to normal object
            save_meta = parse_metadata_from_safetensors(save_meta)
-            self.sd.save(
-                file_path,
-                save_meta,
-                get_torch_dtype(self.save_config.dtype)
-            )
+            if self.sd.refiner_unet and self.train_config.train_refiner:
+                # save refiner
+                refiner_name = self.job.name + '_refiner'
+                filename = f'{refiner_name}{step_num}.safetensors'
+                file_path = os.path.join(self.save_root, filename)
+                self.sd.save_refiner(
+                    file_path,
+                    save_meta,
+                    get_torch_dtype(self.save_config.dtype)
+                )
+
+            if self.train_config.train_unet or self.train_config.train_text_encoder:
+                self.sd.save(
+                    file_path,
+                    save_meta,
+                    get_torch_dtype(self.save_config.dtype)
+                )

        # save learnable params as json if we have thim
        if self.snr_gos:
            json_data = {
-                'offset': self.snr_gos.offset.item(),
+                'offset_1': self.snr_gos.offset_1.item(),
+                'offset_2': self.snr_gos.offset_2.item(),
                'scale': self.snr_gos.scale.item(),
                'gamma': self.snr_gos.gamma.item(),
            }
@@ -447,7 +461,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # Filter out non-existent paths and sort by creation time
        if paths:
            paths = [p for p in paths if os.path.exists(p)]
-            latest_path = max(paths, key=os.path.getctime)
+            # remove false positives
+            if '_LoRA' not in name:
+                paths = [p for p in paths if '_LoRA' not in p]
+            if '_refiner' not in name:
+                paths = [p for p in paths if '_refiner' not in p]
+            if '_t2i' not in name:
+                paths = [p for p in paths if '_t2i' not in p]
+
+            if len(paths) > 0:
+                latest_path = max(paths, key=os.path.getctime)

        return latest_path

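Note: the extra filtering above matters because refiner (and LoRA / t2i) checkpoints now share the job's save folder, so a plain max() over creation times could return the wrong file. A small sketch of the behaviour with illustrative file names, not taken from this commit:

    name = 'my_job'
    paths = [
        '/output/my_job_000000500.safetensors',
        '/output/my_job_refiner000000500.safetensors',
    ]
    # base-model lookups now ignore refiner saves (same idea for '_LoRA' and '_t2i')
    if '_refiner' not in name:
        paths = [p for p in paths if '_refiner' not in p]
    print(paths)  # ['/output/my_job_000000500.safetensors']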
@@ -540,6 +563,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                # double batch and add short captions to the end
                prompts = prompts + batch.get_caption_short_list()
                is_reg_list = is_reg_list + is_reg_list
+            if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
+                prompts = prompts + prompts
+                is_reg_list = is_reg_list + is_reg_list

            conditioned_prompts = []

@@ -587,6 +613,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)

            batch_size = len(batch.file_items)
+            min_noise_steps = self.train_config.min_denoising_steps
+            max_noise_steps = self.train_config.max_denoising_steps
+            if self.model_config.refiner_name_or_path is not None:
+                # if we are not training the unet, then we are only doing refiner and do not need to double up
+                if self.train_config.train_unet:
+                    max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+                    do_double = True
+                else:
+                    min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+                    do_double = False

            with self.timer('prepare_noise'):

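Note: the new min/max variables split the denoising schedule at refiner_start_at. A quick sketch of the arithmetic, using the 0.5 default added in the config changes below and an assumed 1000-step schedule:

    max_denoising_steps = 1000
    refiner_start_at = 0.5
    boundary = round(max_denoising_steps * refiner_start_at)  # 500
    # train_unet=True:  the base UNet samples timesteps below the boundary; the refiner gets the rest later
    # train_unet=False: only the refiner range [boundary, max_denoising_steps) is sampled
    print(boundary)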
@@ -615,18 +651,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
                        timesteps,
                        0,
                        self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
-                        self.train_config.min_denoising_steps,
-                        self.train_config.max_denoising_steps
+                        min_noise_steps,
+                        max_noise_steps
                    )
                    timesteps = timesteps.long().clamp(
-                        self.train_config.min_denoising_steps + 1,
-                        self.train_config.max_denoising_steps - 1
+                        min_noise_steps + 1,
+                        max_noise_steps - 1
                    )

                elif self.train_config.content_or_style == 'balanced':
                    timesteps = torch.randint(
-                        self.train_config.min_denoising_steps,
-                        self.train_config.max_denoising_steps,
+                        min_noise_steps,
+                        max_noise_steps,
                        (batch_size,),
                        device=self.device_torch
                    )
@@ -678,9 +714,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
                return torch.cat([tensor, tensor], dim=0)

            if do_double:
-                noisy_latents = double_up_tensor(noisy_latents)
+                if self.model_config.refiner_name_or_path:
+                    # apply refiner double up
+                    refiner_timesteps = torch.randint(
+                        max_noise_steps,
+                        self.train_config.max_denoising_steps,
+                        (batch_size,),
+                        device=self.device_torch
+                    )
+                    refiner_timesteps = refiner_timesteps.long()
+                    # add our new timesteps on to end
+                    timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
+
+                    refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
+                    noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
+
+                else:
+                    # just double it
+                    noisy_latents = double_up_tensor(noisy_latents)
+                    timesteps = double_up_tensor(timesteps)
+
                noise = double_up_tensor(noise)
-                timesteps = double_up_tensor(timesteps)
                # prompts are already updated above
                imgs = double_up_tensor(imgs)
                batch.mask_tensor = double_up_tensor(batch.mask_tensor)
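Note: when a refiner is present, the batch is no longer simply doubled; the second half gets fresh timesteps from the refiner's range and latents re-noised at those timesteps. A runnable sketch of the resulting layout (the batch size and the 500/1000 boundary are illustrative assumptions):

    import torch

    batch_size = 2
    max_noise_steps = 500          # boundary derived from refiner_start_at
    max_denoising_steps = 1000

    timesteps = torch.randint(0, max_noise_steps, (batch_size,))
    refiner_timesteps = torch.randint(max_noise_steps, max_denoising_steps, (batch_size,))
    timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
    # first half of the batch -> base UNet range, second half -> refiner range
    print(timesteps.shape)  # torch.Size([4])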
@@ -772,6 +826,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # get the noise scheduler
        sampler = get_sampler(self.train_config.noise_scheduler)

+        if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
+            previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
+            if previous_refiner_save is not None:
+                model_config_to_load.refiner_name_or_path = previous_refiner_save
+                self.load_training_state_from_metadata(previous_refiner_save)
+
        self.sd = StableDiffusion(
            device=self.device,
            model_config=model_config_to_load,
@@ -818,6 +878,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if hasattr(text_encoder, "gradient_checkpointing_enable"):
                    text_encoder.gradient_checkpointing_enable()

+        if self.sd.refiner_unet is not None:
+            self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
+            self.sd.refiner_unet.requires_grad_(False)
+            self.sd.refiner_unet.eval()
+            if self.train_config.xformers:
+                self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
+            if self.train_config.gradient_checkpointing:
+                self.sd.refiner_unet.enable_gradient_checkpointing()
+
        if isinstance(text_encoder, list):
            for te in text_encoder:
                te.requires_grad_(False)
@@ -840,7 +909,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
            if os.path.exists(path_to_load):
                with open(path_to_load, 'r') as f:
                    json_data = json.load(f)
-                    self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+                    if 'offset' in json_data:
+                        # legacy
+                        self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
+                    else:
+                        self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
+                        self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
                    self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
                    self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

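Note: the learnable-noise-parameters JSON written at save time now carries two offsets, and the loader above accepts both layouts. The numbers below are placeholders, not values from this commit:

    legacy_json = {'offset': 0.0, 'scale': 1.0, 'gamma': 5.0}
    current_json = {'offset_1': 0.0, 'offset_2': 0.0, 'scale': 1.0, 'gamma': 5.0}
    # When only 'offset' exists, it is loaded into offset_2 (the legacy branch above).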
@@ -1018,7 +1092,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                text_encoder=self.train_config.train_text_encoder,
                text_encoder_lr=self.train_config.lr,
                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
+                default_lr=self.train_config.lr,
+                refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
+                refiner_lr=self.train_config.refiner_lr,
            )
        # we may be using it for prompt injections
        if self.adapter_config is not None:
@@ -1158,6 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
        else:
            batch = None

+        # if we are doing a reg step, always accumulate
+        if is_reg_step:
+            self.is_grad_accumulation_step = True
+
        # setup accumulation
        if self.train_config.gradient_accumulation_steps == -1:
            # epoch is handling the accumulation, dont touch it
@@ -6,6 +6,8 @@ import os
 # add project root to sys path
 import sys

+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 import torch
@@ -50,6 +52,7 @@ parser.add_argument(

 parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
 parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
 parser.add_argument('--ssd', action='store_true', help='is ssd model')
 parser.add_argument('--sd2', action='store_true', help='is sd 2 model')

@@ -61,29 +64,68 @@ find_matches = False

 print(f'Loading diffusers model')

+ignore_ldm_begins_with = []
+
 diffusers_file_path = file_path
 if args.ssd:
     diffusers_file_path = "segmind/SSD-1B"

-diffusers_model_config = ModelConfig(
-    name_or_path=diffusers_file_path,
-    is_xl=args.sdxl,
-    is_v2=args.sd2,
-    is_ssd=args.ssd,
-    dtype=dtype,
-)
-diffusers_sd = StableDiffusion(
-    model_config=diffusers_model_config,
-    device=device,
-    dtype=dtype,
-)
-diffusers_sd.load_model()
-# delete things we dont need
-del diffusers_sd.tokenizer
-flush()
+if args.refiner:
+    diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"

-print(f'Loading ldm model')
-diffusers_state_dict = diffusers_sd.state_dict()
+if not args.refiner:
+    diffusers_model_config = ModelConfig(
+        name_or_path=diffusers_file_path,
+        is_xl=args.sdxl,
+        is_v2=args.sd2,
+        is_ssd=args.ssd,
+        dtype=dtype,
+    )
+    diffusers_sd = StableDiffusion(
+        model_config=diffusers_model_config,
+        device=device,
+        dtype=dtype,
+    )
+    diffusers_sd.load_model()
+    # delete things we dont need
+    del diffusers_sd.tokenizer
+    flush()
+
+    print(f'Loading ldm model')
+    diffusers_state_dict = diffusers_sd.state_dict()
+else:
+    # refiner wont work directly with stable diffusion
+    # so we need to load the model and then load the state dict
+    diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
+        file_path,
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+        variant="fp16",
+    ).to(device)
+
+    SD_PREFIX_VAE = "vae"
+    SD_PREFIX_UNET = "unet"
+    SD_PREFIX_REFINER_UNET = "refiner_unet"
+    SD_PREFIX_TEXT_ENCODER = "te"
+
+    SD_PREFIX_TEXT_ENCODER1 = "te0"
+    SD_PREFIX_TEXT_ENCODER2 = "te1"
+
+    diffusers_state_dict = OrderedDict()
+    for k, v in diffusers_pipeline.vae.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+        diffusers_state_dict[new_key] = v
+    for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
+        diffusers_state_dict[new_key] = v
+    for k, v in diffusers_pipeline.unet.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+        diffusers_state_dict[new_key] = v
+
+    # add ignore ones as we are only going to focus on unet and copy the rest
+    # ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
+
 diffusers_dict_keys = list(diffusers_state_dict.keys())

 ldm_state_dict = load_file(file_path)
@@ -113,6 +155,12 @@ if args.sdxl or args.ssd:
     proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
     proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
     text_proj_layer = "conditioner.embedders.1.model.text_projection"
+if args.refiner:
+    te_suffix = '1'
+    ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
+    proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
+    proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
+    text_proj_layer = "conditioner.embedders.0.model.text_projection"
 if args.sd2:
     te_suffix = ''
     ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
@@ -120,7 +168,7 @@ if args.sd2:
     proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
     text_proj_layer = "cond_stage_model.model.text_projection"

-if args.sdxl or args.sd2 or args.ssd:
+if args.sdxl or args.sd2 or args.ssd or args.refiner:
     if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
         # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
         d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
@@ -297,6 +345,8 @@ if args.sdxl:
     name += '_sdxl'
 elif args.ssd:
     name += '_ssd'
+elif args.refiner:
+    name += '_refiner'
 elif args.sd2:
     name += '_sd2'
 else:
@@ -160,6 +160,7 @@ class TrainConfig:
         self.lr = kwargs.get('lr', 1e-6)
         self.unet_lr = kwargs.get('unet_lr', self.lr)
         self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
+        self.refiner_lr = kwargs.get('refiner_lr', self.lr)
         self.embedding_lr = kwargs.get('embedding_lr', self.lr)
         self.adapter_lr = kwargs.get('adapter_lr', self.lr)
         self.optimizer = kwargs.get('optimizer', 'adamw')
@@ -174,6 +175,7 @@ class TrainConfig:
         self.sdp = kwargs.get('sdp', False)
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
+        self.train_refiner = kwargs.get('train_refiner', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.snr_gamma = kwargs.get('snr_gamma', None)
         # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
@@ -238,6 +240,8 @@ class ModelConfig:
         self.is_v_pred: bool = kwargs.get('is_v_pred', False)
         self.dtype: str = kwargs.get('dtype', 'float16')
         self.vae_path = kwargs.get('vae_path', None)
+        self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
+        self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)

         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
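Note: the config keys added above are all a refiner training job needs to touch. A minimal sketch of the kwargs the two classes consume; the import path and the model ids are assumptions, while refiner_start_at=0.5 and train_refiner=True are the defaults added above:

    from toolkit.config_modules import ModelConfig, TrainConfig  # module path assumed

    model_config = ModelConfig(
        name_or_path='stabilityai/stable-diffusion-xl-base-1.0',            # assumed base model
        is_xl=True,
        refiner_name_or_path='stabilityai/stable-diffusion-xl-refiner-1.0',
        refiner_start_at=0.5,
    )
    train_config = TrainConfig(
        train_unet=True,
        train_refiner=True,
        refiner_lr=1e-6,  # falls back to lr when omitted
    )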
New files added by this commit:
toolkit/keymaps/stable_diffusion_refiner.json (3499 lines; file diff suppressed because it is too large)
toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors (binary file, not shown)
toolkit/keymaps/stable_diffusion_refiner_unmatched.json (10 lines):
@@ -0,0 +1,10 @@
+{
+    "ldm": {
+        "conditioner.embedders.0.model.logit_scale": {
+            "shape": [],
+            "min": 4.60546875,
+            "max": 4.60546875
+        }
+    },
+    "diffusers": {}
+}
toolkit/orig_configs/sd_xl_refiner.yaml (new file, 91 lines):
@@ -0,0 +1,91 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        weighting_config:
+          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2560
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 384
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 4
+        context_dim: [1280, 1280, 1280, 1280]  # 1280
+        spatial_transformer_attn_type: softmax-xformers
+        legacy: False
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          # crossattn and vector cond
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+            params:
+              arch: ViT-bigG-14
+              version: laion2b_s39b_b160k
+              legacy: False
+              freeze: True
+              layer: penultimate
+              always_return_pooled: True
+          # vector cond
+          - is_trainable: False
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by two
+          # vector cond
+          - is_trainable: False
+            input_key: aesthetic_score
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256  # multiplied by one
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
@@ -5,6 +5,8 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
 SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
 REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
 KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
+ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
+DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")

 # check if ENV variable is set
 if 'MODELS_PATH' in os.environ:
@@ -8,10 +8,18 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
 from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from diffusers.utils import is_torch_xla_available
 from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
 from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):

     def __init__(
@@ -807,3 +815,389 @@ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
                 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

         return noise_pred
+
+
+class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        denoising_end: Optional[float] = None,
+        denoising_start: Optional[float] = None,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Optional[Tuple[int, int]] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Optional[Tuple[int, int]] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
+                Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be as same
+                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+        # 0. Default height and width to unet
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 8.1 Apply denoising_end
+        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 8.2 Determine denoising_start
+        denoising_start_index = 0
+        if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
+            discrete_timestep_start = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
+
+        with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
+            for i, t in enumerate(timesteps, start=denoising_start_index):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
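Note: in the new pipeline above, denoising_start is converted to an index into the timestep schedule before the loop runs. A small self-contained sketch of that conversion with assumed values (40 inference steps, 1000 train timesteps, start at 0.5; the linear schedule here is only an approximation of what a real scheduler would produce):

    import torch

    num_train_timesteps = 1000
    num_inference_steps = 40
    denoising_start = 0.5

    # descending schedule, roughly what the scheduler would hand back (assumed here)
    timesteps = torch.linspace(num_train_timesteps - 1, 0, num_inference_steps).long()

    discrete_timestep_start = int(round(num_train_timesteps - denoising_start * num_train_timesteps))  # 500
    denoising_start_index = len([ts for ts in timesteps if ts < discrete_timestep_start])
    # the progress bar above then covers num_inference_steps - denoising_start_index steps
    print(denoising_start_index)  # roughly 20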
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(

 def get_ldm_state_dict_from_diffusers(
         state_dict: 'OrderedDict',
-        sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
+        sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
         device='cpu',
         dtype=get_torch_dtype('fp32'),
 ):
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
         # load our base
         base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
         mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
+    elif sd_version == 'sdxl_refiner':
+        # load our base
+        base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
+        mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
     else:
         raise ValueError(f"Invalid sd_version {sd_version}")

@@ -23,6 +23,11 @@ empty_preset = {
        'requires_grad': False,
        'device': 'cpu',
    },
+    'refiner_unet': {
+        'training': False,
+        'requires_grad': False,
+        'device': 'cpu',
+    },
}


@@ -34,6 +39,7 @@ def get_train_sd_device_state_preset(
        train_lora: bool = False,
        train_adapter: bool = False,
        train_embedding: bool = False,
+        train_refiner: bool = False,
):
    preset = copy.deepcopy(empty_preset)
    if not cached_latents:
@@ -59,9 +65,16 @@ def get_train_sd_device_state_preset(
        preset['text_encoder']['training'] = True
        preset['unet']['training'] = True

+    if train_refiner:
+        preset['refiner_unet']['training'] = True
+        preset['refiner_unet']['requires_grad'] = True
+        preset['refiner_unet']['device'] = device
+
    if train_lora:
        # preset['text_encoder']['requires_grad'] = False
        preset['unet']['requires_grad'] = False
+        if train_refiner:
+            preset['refiner_unet']['requires_grad'] = False

    if train_adapter:
        preset['adapter']['requires_grad'] = True
@@ -26,21 +26,28 @@ from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
 from toolkit.sampler import get_sampler
-from toolkit.saving import save_ldm_model_from_diffusers
+from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
-    StableDiffusionKDiffusionXLPipeline
+    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
+    StableDiffusionXLImg2ImgPipeline
 import diffusers
+from diffusers import \
+    AutoencoderKL, \
+    UNet2DConditionModel
+
+from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
+
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)

 SD_PREFIX_VAE = "vae"
 SD_PREFIX_UNET = "unet"
+SD_PREFIX_REFINER_UNET = "refiner_unet"
 SD_PREFIX_TEXT_ENCODER = "te"

 SD_PREFIX_TEXT_ENCODER1 = "te0"
@@ -52,6 +59,10 @@ DO_NOT_TRAIN_WEIGHTS = [
|
|||||||
"unet_time_embedding.linear_1.weight",
|
"unet_time_embedding.linear_1.weight",
|
||||||
"unet_time_embedding.linear_2.bias",
|
"unet_time_embedding.linear_2.bias",
|
||||||
"unet_time_embedding.linear_2.weight",
|
"unet_time_embedding.linear_2.weight",
|
||||||
|
"refiner_unet_time_embedding.linear_1.bias",
|
||||||
|
"refiner_unet_time_embedding.linear_1.weight",
|
||||||
|
"refiner_unet_time_embedding.linear_2.bias",
|
||||||
|
"refiner_unet_time_embedding.linear_2.weight",
|
||||||
]
|
]
|
||||||
|
|
||||||
DeviceStatePreset = Literal['cache_latents', 'generate']
|
DeviceStatePreset = Literal['cache_latents', 'generate']
|
||||||
@@ -81,10 +92,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
|||||||
|
|
||||||
# if is type checking
|
# if is type checking
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from diffusers import \
|
|
||||||
StableDiffusionPipeline, \
|
|
||||||
AutoencoderKL, \
|
|
||||||
UNet2DConditionModel
|
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||||
|
|
||||||
@@ -116,6 +123,8 @@ class StableDiffusion:
|
|||||||
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
||||||
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
||||||
|
|
||||||
|
self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
|
||||||
|
|
||||||
# sdxl stuff
|
# sdxl stuff
|
||||||
self.logit_scale = None
|
self.logit_scale = None
|
||||||
self.ckppt_info = None
|
self.ckppt_info = None
|
||||||
@@ -214,7 +223,7 @@ class StableDiffusion:
             pipln = StableDiffusionPipeline

         # see if path exists
-        if not os.path.exists(model_path):
+        if not os.path.exists(model_path) or os.path.isdir(model_path):
             # try to load with default diffusers
             pipe = pipln.from_pretrained(
                 model_path,
@@ -263,10 +272,47 @@ class StableDiffusion:
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.pipeline = pipe
+        self.load_refiner()
         self.is_loaded = True

+    def load_refiner(self):
+        # for now, we are just going to rely on the TE from the base model
+        # which is TE2 for SDXL and TE for SD (no refiner currently)
+        # and completely ignore a TE that may or may not be packaged with the refiner
+        if self.model_config.refiner_name_or_path is not None:
+            refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
+            # load the refiner model
+            dtype = get_torch_dtype(self.dtype)
+            model_path = self.model_config.refiner_name_or_path
+            if not os.path.exists(model_path) or os.path.isdir(model_path):
+                # TODO only load unet??
+                refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+                    model_path,
+                    dtype=dtype,
+                    device=self.device_torch,
+                    variant="fp16",
+                    use_safetensors=True,
+                ).to(self.device_torch)
+            else:
+                refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+                    model_path,
+                    dtype=dtype,
+                    device=self.device_torch,
+                    torch_dtype=self.torch_dtype,
+                    original_config_file=refiner_config_path,
+                ).to(self.device_torch)
+
+            self.refiner_unet = refiner.unet
+            del refiner
+            flush()
+
     @torch.no_grad()
-    def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
+    def generate_images(
+            self,
+            image_configs: List[GenerateImageConfig],
+            sampler=None,
+            pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
+    ):
         merge_multiplier = 1.0
         # sample_folder = os.path.join(self.save_root, 'samples')
         if self.network is not None:
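The "TODO only load unet??" above could in principle be handled by loading just the refiner's UNet weights rather than instantiating the whole img2img pipeline. A minimal sketch under that assumption (not what this commit does), reusing the names from load_refiner above:

    # Sketch only: pull just the UNet out of a diffusers-format refiner checkpoint,
    # skipping the VAE and text-encoder weights the full pipeline load also brings in.
    from diffusers import UNet2DConditionModel

    refiner_unet = UNet2DConditionModel.from_pretrained(
        model_path,            # the refiner_name_or_path folder/repo assumed above
        subfolder="unet",
        torch_dtype=dtype,
        variant="fp16",
        use_safetensors=True,
    ).to(self.device_torch)
    self.refiner_unet = refiner_unet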
@@ -289,65 +335,85 @@ class StableDiffusion:
         rng_state = torch.get_rng_state()
         cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

-        noise_scheduler = self.noise_scheduler
-        if sampler is not None:
-            if sampler.startswith("sample_"):  # sample_dpmpp_2m
-                # using ksampler
-                noise_scheduler = get_sampler('lms')
-            else:
-                noise_scheduler = get_sampler(sampler)
-
-        if sampler.startswith("sample_") and self.is_xl:
-            # using kdiffusion
-            Pipe = StableDiffusionKDiffusionXLPipeline
-        elif self.is_xl:
-            Pipe = StableDiffusionXLPipeline
-        else:
-            Pipe = StableDiffusionPipeline
-
-        extra_args = {}
-        if self.adapter is not None:
-            if isinstance(self.adapter, T2IAdapter):
-                if self.is_xl:
-                    Pipe = StableDiffusionXLAdapterPipeline
-                else:
-                    Pipe = StableDiffusionAdapterPipeline
-                extra_args['adapter'] = self.adapter
-        else:
-            if self.is_xl:
-                extra_args['add_watermarker'] = False
-
-        # TODO add clip skip
-        if self.is_xl:
-            pipeline = Pipe(
-                vae=self.vae,
-                unet=self.unet,
-                text_encoder=self.text_encoder[0],
-                text_encoder_2=self.text_encoder[1],
-                tokenizer=self.tokenizer[0],
-                tokenizer_2=self.tokenizer[1],
-                scheduler=noise_scheduler,
-                **extra_args
-            ).to(self.device_torch)
-            pipeline.watermark = None
-        else:
-            pipeline = Pipe(
-                vae=self.vae,
-                unet=self.unet,
-                text_encoder=self.text_encoder,
-                tokenizer=self.tokenizer,
-                scheduler=noise_scheduler,
-                safety_checker=None,
-                feature_extractor=None,
-                requires_safety_checker=False,
-                **extra_args
-            ).to(self.device_torch)
-        flush()
-        # disable progress bar
-        pipeline.set_progress_bar_config(disable=True)
-
-        if sampler.startswith("sample_"):
-            pipeline.set_scheduler(sampler)
+        if pipeline is None:
+            noise_scheduler = self.noise_scheduler
+            if sampler is not None:
+                if sampler.startswith("sample_"):  # sample_dpmpp_2m
+                    # using ksampler
+                    noise_scheduler = get_sampler('lms')
+                else:
+                    noise_scheduler = get_sampler(sampler)
+
+            if sampler.startswith("sample_") and self.is_xl:
+                # using kdiffusion
+                Pipe = StableDiffusionKDiffusionXLPipeline
+            elif self.is_xl:
+                Pipe = StableDiffusionXLPipeline
+            else:
+                Pipe = StableDiffusionPipeline
+
+            extra_args = {}
+            if self.adapter is not None:
+                if isinstance(self.adapter, T2IAdapter):
+                    if self.is_xl:
+                        Pipe = StableDiffusionXLAdapterPipeline
+                    else:
+                        Pipe = StableDiffusionAdapterPipeline
+                    extra_args['adapter'] = self.adapter
+            else:
+                if self.is_xl:
+                    extra_args['add_watermarker'] = False
+
+            # TODO add clip skip
+            if self.is_xl:
+                pipeline = Pipe(
+                    vae=self.vae,
+                    unet=self.unet,
+                    text_encoder=self.text_encoder[0],
+                    text_encoder_2=self.text_encoder[1],
+                    tokenizer=self.tokenizer[0],
+                    tokenizer_2=self.tokenizer[1],
+                    scheduler=noise_scheduler,
+                    **extra_args
+                ).to(self.device_torch)
+                pipeline.watermark = None
+            else:
+                pipeline = Pipe(
+                    vae=self.vae,
+                    unet=self.unet,
+                    text_encoder=self.text_encoder,
+                    tokenizer=self.tokenizer,
+                    scheduler=noise_scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    **extra_args
+                ).to(self.device_torch)
+            flush()
+            # disable progress bar
+            pipeline.set_progress_bar_config(disable=True)
+
+            if sampler.startswith("sample_"):
+                pipeline.set_scheduler(sampler)
+
+        refiner_pipeline = None
+        if self.refiner_unet:
+            # build refiner pipeline
+            refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
+                vae=pipeline.vae,
+                unet=self.refiner_unet,
+                text_encoder=None,
+                text_encoder_2=pipeline.text_encoder_2,
+                tokenizer=None,
+                tokenizer_2=pipeline.tokenizer_2,
+                scheduler=pipeline.scheduler,
+                add_watermarker=False,
+                requires_aesthetics_score=True,
+            ).to(self.device_torch)
+            # refiner_pipeline.register_to_config(requires_aesthetics_score=False)
+            refiner_pipeline.watermark = None
+            refiner_pipeline.set_progress_bar_config(disable=True)
+            flush()

         start_multiplier = 1.0
         if self.network is not None:
@@ -406,14 +472,20 @@ class StableDiffusion:
                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                    unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)

-                # todo do we disable text encoder here as well if disabled for model, or only do that for training?
+                if self.refiner_unet is not None:
+                    # if we have a refiner loaded, set the denoising end at the refiner start
+                    extra['denoising_end'] = self.model_config.refiner_start_at
+                    extra['output_type'] = 'latent'
+                    if not self.is_xl:
+                        raise ValueError("Refiner is only supported for XL models")
+
                 if self.is_xl:
                     # fix guidance rescale for sdxl
                     # was trained on 0.7 (I believe)

                     grs = gen_config.guidance_rescale
-                    if grs is None or grs < 0.00001:
-                        grs = 0.7
+                    # if grs is None or grs < 0.00001:
+                    #     grs = 0.7
                     # grs = 0.0

                 if sampler.startswith("sample_"):
@@ -454,10 +526,37 @@ class StableDiffusion:
                     **extra
                 ).images[0]

+                if refiner_pipeline is not None:
+                    # slide off just the last 1280 on the last dim as refiner does not use first text encoder
+                    # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
+                    refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
+                    refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
+                    # run through refiner
+                    img = refiner_pipeline(
+                        # prompt=gen_config.prompt,
+                        # prompt_2=gen_config.prompt_2,
+
+                        # slice these as it does not use both text encoders
+                        # height=gen_config.height,
+                        # width=gen_config.width,
+                        prompt_embeds=refiner_text_embeds,
+                        pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                        negative_prompt_embeds=refiner_unconditional_text_embeds,
+                        negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                        num_inference_steps=gen_config.num_inference_steps,
+                        guidance_scale=gen_config.guidance_scale,
+                        guidance_rescale=grs,
+                        denoising_start=self.model_config.refiner_start_at,
+                        denoising_end=gen_config.num_inference_steps,
+                        image=img.unsqueeze(0)
+                    ).images[0]
+
                 gen_config.save_image(img, i)

             # clear pipeline and cache to reduce vram usage
             del pipeline
+            if refiner_pipeline is not None:
+                del refiner_pipeline
             torch.cuda.empty_cache()

             # restore training state
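The two hunks above implement the SDXL base/refiner handoff: the base pipeline stops denoising at refiner_start_at (denoising_end) and returns latents, and the img2img refiner resumes from the same point (denoising_start). In diffusers those two arguments are fractions of the schedule in [0, 1]. A minimal standalone sketch of the same handoff with illustrative values; base_pipe and refiner_pipe stand in for the pipelines built earlier:

    # Illustrative only, not part of the commit.
    refiner_start_at = 0.8
    prompt = "a photo of a red fox in the snow"   # example prompt

    latents = base_pipe(
        prompt=prompt,
        num_inference_steps=30,
        denoising_end=refiner_start_at,    # base covers the first 80% of the schedule
        output_type="latent",
    ).images

    image = refiner_pipe(
        prompt=prompt,
        num_inference_steps=30,
        denoising_start=refiner_start_at,  # refiner finishes the remaining 20%
        image=latents,
    ).images[0]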
@@ -505,7 +604,7 @@ class StableDiffusion:
             noise = apply_noise_offset(noise, noise_offset)
         return noise

-    def get_time_ids_from_latents(self, latents: torch.Tensor):
+    def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
         VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
         if self.is_xl:
             bs, ch, h, w = list(latents.shape)
@@ -518,7 +617,13 @@ class StableDiffusion:
             target_size = (height, width)
             original_size = (height, width)
             crops_coords_top_left = (0, 0)
-            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            if requires_aesthetic_score:
+                # refiner
+                # https://huggingface.co/papers/2307.01952
+                aesthetic_score = 6.0  # simulate one
+                add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            else:
+                add_time_ids = list(original_size + crops_coords_top_left + target_size)
             add_time_ids = torch.tensor([add_time_ids])
             add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
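The branch added here changes the layout of SDXL's added time-id conditioning: the base UNet receives six values (original size, crop coordinates, target size), while the refiner replaces the target size with a single aesthetic score. A small worked example, assuming a 1024x1024 image and no crop:

    original_size = (1024, 1024)
    crops_coords_top_left = (0, 0)
    target_size = (1024, 1024)
    aesthetic_score = 6.0

    base_ids = list(original_size + crops_coords_top_left + target_size)
    # -> [1024, 1024, 0, 0, 1024, 1024]   (6 values, for the base UNet)

    refiner_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
    # -> [1024, 1024, 0, 0, 6.0]          (5 values; the refiner is conditioned on an aesthetic score)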
@@ -588,14 +693,68 @@ class StableDiffusion:
             "time_ids": add_time_ids,
         }

-        # predict the noise residual
-        noise_pred = self.unet(
-            latent_model_input,
-            timestep,
-            encoder_hidden_states=text_embeddings.text_embeds,
-            added_cond_kwargs=added_cond_kwargs,
-            **kwargs,
-        ).sample
+        if self.model_config.refiner_name_or_path is not None:
+            # we have the refiner on the second half of everything. Do Both
+            if do_classifier_free_guidance:
+                raise ValueError("Refiner is not supported with classifier free guidance")
+
+            if self.unet.training:
+                input_chunks = torch.chunk(latent_model_input, 2, dim=0)
+                timestep_chunks = torch.chunk(timestep, 2, dim=0)
+                added_cond_kwargs_chunked = {
+                    "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
+                    "time_ids": torch.chunk(add_time_ids, 2, dim=0),
+                }
+                text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)
+
+                # predict the noise residual
+                base_pred = self.unet(
+                    input_chunks[0],
+                    timestep_chunks[0],
+                    encoder_hidden_states=text_embeds_chunks[0],
+                    added_cond_kwargs={
+                        "text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
+                        "time_ids": added_cond_kwargs_chunked['time_ids'][0],
+                    },
+                    **kwargs,
+                ).sample
+
+                refiner_pred = self.refiner_unet(
+                    input_chunks[1],
+                    timestep_chunks[1],
+                    encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],  # just use the first second text encoder
+                    added_cond_kwargs={
+                        "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
+                        # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
+                        "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
+                    },
+                    **kwargs,
+                ).sample
+
+                noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
+            else:
+                noise_pred = self.refiner_unet(
+                    latent_model_input,
+                    timestep,
+                    encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
+                    # just use the first second text encoder
+                    added_cond_kwargs={
+                        "text_embeds": text_embeddings.pooled_embeds,
+                        "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
+                    },
+                    **kwargs,
+                ).sample
+
+        else:
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                latent_model_input,
+                timestep,
+                encoder_hidden_states=text_embeddings.text_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+                **kwargs,
+            ).sample

         if do_classifier_free_guidance:
             # perform guidance
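During training, the branch above splits the batch in half rather than doubling it: the first half of the noisy latents and timesteps goes through the base UNet, the second half through the refiner UNet, and the two predictions are concatenated so the loss still sees one prediction per sample. The splitting pattern in isolation (shapes are illustrative; the UNet calls themselves are elided):

    import torch

    noisy_latents = torch.randn(4, 4, 128, 128)        # batch of 4
    timesteps = torch.randint(0, 1000, (4,))

    lat_base, lat_refiner = torch.chunk(noisy_latents, 2, dim=0)   # 2 + 2 samples
    t_base, t_refiner = torch.chunk(timesteps, 2, dim=0)

    # ...base UNet consumes (lat_base, t_base), refiner UNet consumes (lat_refiner, t_refiner)...
    pred_base, pred_refiner = lat_base, lat_refiner     # placeholders for the two .sample outputs

    noise_pred = torch.cat([pred_base, pred_refiner], dim=0)       # back to one prediction per sample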
@@ -852,7 +1011,7 @@ class StableDiffusion:
             state_dict[new_key] = v
         return state_dict

-    def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
+    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
         str, Parameter]:
         named_params: OrderedDict[str, Parameter] = OrderedDict()
         if vae:
@@ -877,6 +1036,10 @@ class StableDiffusion:
             for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param

+        if refiner:
+            for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
+                named_params[name] = param
+
         # convert to state dict keys, jsut replace . with _ on keys
         if state_dict_keys:
             new_named_params = OrderedDict()
@@ -888,6 +1051,64 @@ class StableDiffusion:

         return named_params

+    def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
+
+        # load the full refiner since we only train unet
+        if self.model_config.refiner_name_or_path is None:
+            raise ValueError("Refiner must be specified to save it")
+        refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
+        # load the refiner model
+        dtype = get_torch_dtype(self.dtype)
+        model_path = self.model_config.refiner_name_or_path
+        if not os.path.exists(model_path) or os.path.isdir(model_path):
+            # TODO only load unet??
+            refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+                model_path,
+                dtype=dtype,
+                device='cpu',
+                variant="fp16",
+                use_safetensors=True,
+            )
+        else:
+            refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+                model_path,
+                dtype=dtype,
+                device='cpu',
+                torch_dtype=self.torch_dtype,
+                original_config_file=refiner_config_path,
+            )
+        # replace original unet
+        refiner.unet = self.refiner_unet
+        flush()
+
+        diffusers_state_dict = OrderedDict()
+        for k, v in refiner.vae.state_dict().items():
+            new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+            diffusers_state_dict[new_key] = v
+        for k, v in refiner.text_encoder_2.state_dict().items():
+            new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
+            diffusers_state_dict[new_key] = v
+        for k, v in refiner.unet.state_dict().items():
+            new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+            diffusers_state_dict[new_key] = v
+
+        converted_state_dict = get_ldm_state_dict_from_diffusers(
+            diffusers_state_dict,
+            'sdxl_refiner',
+            device='cpu',
+            dtype=save_dtype
+        )
+
+        # make sure parent folder exists
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
+        save_file(converted_state_dict, output_file, metadata=meta)
+
+        if self.config_file is not None:
+            output_path_no_ext = os.path.splitext(output_file)[0]
+            output_config_path = f"{output_path_no_ext}.yaml"
+            shutil.copyfile(self.config_file, output_config_path)
+
+
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         version_string = '1'
         if self.is_v2:
@@ -929,6 +1150,8 @@ class StableDiffusion:
             text_encoder=False,
             text_encoder_lr=None,
             unet_lr=None,
+            refiner_lr=None,
+            refiner=False,
             default_lr=1e-6,
     ):
         # todo maybe only get locon ones?
@@ -974,6 +1197,20 @@ class StableDiffusion:

             print(f"Found {len(params)} trainable parameter in text encoder")

+        if refiner:
+            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
+            refiner_lr = refiner_lr if refiner_lr is not None else default_lr
+            params = []
+            for key, diffusers_key in ldm_diffusers_keymap.items():
+                diffusers_key = f"refiner_{diffusers_key}"
+                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
+                    if named_params[diffusers_key].requires_grad:
+                        params.append(named_params[diffusers_key])
+            param_data = {"params": params, "lr": refiner_lr}
+            trainable_parameters.append(param_data)
+
+            print(f"Found {len(params)} trainable parameter in refiner")
+
         return trainable_parameters

     def save_device_state(self):
@@ -1021,6 +1258,13 @@ class StableDiffusion:
                 'requires_grad': requires_grad,
             }

+        if self.refiner_unet is not None:
+            self.device_state['refiner_unet'] = {
+                'training': self.refiner_unet.training,
+                'device': self.refiner_unet.device,
+                'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
+            }
+
     def restore_device_state(self):
         # restores the device state for all modules
         # this is useful for when we want to alter the state and restore it
@@ -1075,6 +1319,14 @@ class StableDiffusion:
                 self.adapter.train()
             else:
                 self.adapter.eval()

+        if self.refiner_unet is not None:
+            self.refiner_unet.to(state['refiner_unet']['device'])
+            self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
+            if state['refiner_unet']['training']:
+                self.refiner_unet.train()
+            else:
+                self.refiner_unet.eval()
         flush()

     def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -1088,7 +1340,7 @@ class StableDiffusion:
         if device_state_preset in ['cache_latents']:
             active_modules = ['vae']
         if device_state_preset in ['generate']:
-            active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
+            active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']

         state = copy.deepcopy(empty_preset)
         # vae
@@ -1105,6 +1357,13 @@ class StableDiffusion:
             'requires_grad': 'unet' in training_modules,
         }

+        if self.refiner_unet is not None:
+            state['refiner_unet'] = {
+                'training': 'refiner_unet' in training_modules,
+                'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
+                'requires_grad': 'refiner_unet' in training_modules,
+            }
+
         # text encoder
         if isinstance(self.text_encoder, list):
             state['text_encoder'] = []
@@ -691,10 +691,11 @@ class LearnableSNRGamma:
     def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
         self.device = device
         self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
-        self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
+        self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
+        self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
         self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
         self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
-        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
+        self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
         self.buffer = []
         self.max_buffer_size = 20

@@ -711,7 +712,7 @@ class LearnableSNRGamma:
         snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
         base_snrs = snr.clone().detach()
         snr.requires_grad = True
-        snr = snr * self.scale + self.offset
+        snr = (snr + self.offset_1) * self.scale + self.offset_2

         gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
         snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
@@ -726,18 +727,18 @@ class LearnableSNRGamma:
         self.optimizer.step()
         self.optimizer.zero_grad()

-        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+        return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()


 def apply_learnable_snr_gos(
         loss,
         timesteps,
-        learnable_snr_trainer:LearnableSNRGamma
+        learnable_snr_trainer: LearnableSNRGamma
 ):

-    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)

-    snr = snr * scale + offset
+    snr = (snr + offset_1) * scale + offset_2

     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
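Spelled out, the reweighting these last hunks introduce applies one learnable offset before the scale and one after; since offset_1 is initialised to 0.0, the initial behaviour matches the previous single-offset form. A small self-contained restatement with the initial values from the code above and illustrative SNR/loss numbers:

    import torch

    # learnable parameters at their initial values
    offset_1, offset_2 = torch.tensor(0.0), torch.tensor(0.777)
    scale, gamma = torch.tensor(4.14), torch.tensor(2.03)

    snr = torch.tensor([12.0, 3.5, 0.9])        # example per-sample SNRs
    loss = torch.tensor([0.21, 0.18, 0.25])     # example per-sample losses

    snr_adjusted = (snr + offset_1) * scale + offset_2   # previously: snr * scale + offset
    snr_weight = torch.abs(gamma / snr_adjusted)
    weighted_loss = loss * snr_weight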