mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -382,8 +382,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_strength_max = 1.0
|
||||
else:
|
||||
# training with assistance, we want it low
|
||||
adapter_strength_min = 0.5
|
||||
adapter_strength_max = 0.8
|
||||
adapter_strength_min = 0.4
|
||||
adapter_strength_max = 0.7
|
||||
# adapter_strength_min = 0.9
|
||||
# adapter_strength_max = 1.1
|
||||
|
||||
@@ -431,6 +431,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# make the batch splits
|
||||
if self.train_config.single_item_batching:
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
raise ValueError("Single item batching is not supported when training the refiner")
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# chunk/split everything
|
||||
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
||||
@@ -452,7 +454,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt_2_list = [[prompt] for prompt in prompts_2]
|
||||
|
||||
else:
|
||||
# but it all in an array
|
||||
noisy_latents_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
@@ -603,6 +604,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
# gradient accumulation. Just a place for breakpoint
|
||||
pass
|
||||
|
||||
# TODO Should we only step scheduler on grad step? If so, need to recalculate last step
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
@@ -152,6 +152,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
train_lora=self.network_config is not None,
|
||||
train_adapter=is_training_adapter,
|
||||
train_embedding=self.embed_config is not None,
|
||||
train_refiner=self.train_config.train_refiner,
|
||||
)
|
||||
|
||||
# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
|
||||
@@ -382,6 +383,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = file_path.replace('.safetensors', '')
|
||||
# convert it back to normal object
|
||||
save_meta = parse_metadata_from_safetensors(save_meta)
|
||||
|
||||
if self.sd.refiner_unet and self.train_config.train_refiner:
|
||||
# save refiner
|
||||
refiner_name = self.job.name + '_refiner'
|
||||
filename = f'{refiner_name}{step_num}.safetensors'
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
self.sd.save_refiner(
|
||||
file_path,
|
||||
save_meta,
|
||||
get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
if self.train_config.train_unet or self.train_config.train_text_encoder:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
@@ -391,7 +404,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# save learnable params as json if we have thim
|
||||
if self.snr_gos:
|
||||
json_data = {
|
||||
'offset': self.snr_gos.offset.item(),
|
||||
'offset_1': self.snr_gos.offset_1.item(),
|
||||
'offset_2': self.snr_gos.offset_2.item(),
|
||||
'scale': self.snr_gos.scale.item(),
|
||||
'gamma': self.snr_gos.gamma.item(),
|
||||
}
|
||||
@@ -447,6 +461,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# Filter out non-existent paths and sort by creation time
|
||||
if paths:
|
||||
paths = [p for p in paths if os.path.exists(p)]
|
||||
# remove false positives
|
||||
if '_LoRA' not in name:
|
||||
paths = [p for p in paths if '_LoRA' not in p]
|
||||
if '_refiner' not in name:
|
||||
paths = [p for p in paths if '_refiner' not in p]
|
||||
if '_t2i' not in name:
|
||||
paths = [p for p in paths if '_t2i' not in p]
|
||||
|
||||
if len(paths) > 0:
|
||||
latest_path = max(paths, key=os.path.getctime)
|
||||
|
||||
return latest_path
|
||||
@@ -540,6 +563,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# double batch and add short captions to the end
|
||||
prompts = prompts + batch.get_caption_short_list()
|
||||
is_reg_list = is_reg_list + is_reg_list
|
||||
if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
|
||||
prompts = prompts + prompts
|
||||
is_reg_list = is_reg_list + is_reg_list
|
||||
|
||||
conditioned_prompts = []
|
||||
|
||||
@@ -587,6 +613,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
|
||||
batch_size = len(batch.file_items)
|
||||
min_noise_steps = self.train_config.min_denoising_steps
|
||||
max_noise_steps = self.train_config.max_denoising_steps
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# if we are not training the unet, then we are only doing refiner and do not need to double up
|
||||
if self.train_config.train_unet:
|
||||
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = True
|
||||
else:
|
||||
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = False
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
@@ -615,18 +651,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
min_noise_steps,
|
||||
max_noise_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps + 1,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
min_noise_steps + 1,
|
||||
max_noise_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
min_noise_steps,
|
||||
max_noise_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
@@ -678,9 +714,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
return torch.cat([tensor, tensor], dim=0)
|
||||
|
||||
if do_double:
|
||||
if self.model_config.refiner_name_or_path:
|
||||
# apply refiner double up
|
||||
refiner_timesteps = torch.randint(
|
||||
max_noise_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
refiner_timesteps = refiner_timesteps.long()
|
||||
# add our new timesteps on to end
|
||||
timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
|
||||
|
||||
refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
|
||||
noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
|
||||
|
||||
else:
|
||||
# just double it
|
||||
noisy_latents = double_up_tensor(noisy_latents)
|
||||
noise = double_up_tensor(noise)
|
||||
timesteps = double_up_tensor(timesteps)
|
||||
|
||||
noise = double_up_tensor(noise)
|
||||
# prompts are already updated above
|
||||
imgs = double_up_tensor(imgs)
|
||||
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
|
||||
@@ -772,6 +826,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
|
||||
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
|
||||
if previous_refiner_save is not None:
|
||||
model_config_to_load.refiner_name_or_path = previous_refiner_save
|
||||
self.load_training_state_from_metadata(previous_refiner_save)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
@@ -818,6 +878,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if hasattr(text_encoder, "gradient_checkpointing_enable"):
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if self.sd.refiner_unet is not None:
|
||||
self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
|
||||
self.sd.refiner_unet.requires_grad_(False)
|
||||
self.sd.refiner_unet.eval()
|
||||
if self.train_config.xformers:
|
||||
self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.sd.refiner_unet.enable_gradient_checkpointing()
|
||||
|
||||
if isinstance(text_encoder, list):
|
||||
for te in text_encoder:
|
||||
te.requires_grad_(False)
|
||||
@@ -840,7 +909,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if os.path.exists(path_to_load):
|
||||
with open(path_to_load, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
|
||||
if 'offset' in json_data:
|
||||
# legacy
|
||||
self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
|
||||
else:
|
||||
self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
|
||||
self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
|
||||
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
|
||||
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
|
||||
|
||||
@@ -1018,7 +1092,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
default_lr=self.train_config.lr,
|
||||
refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
|
||||
refiner_lr=self.train_config.refiner_lr,
|
||||
)
|
||||
# we may be using it for prompt injections
|
||||
if self.adapter_config is not None:
|
||||
@@ -1158,6 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
batch = None
|
||||
|
||||
# if we are doing a reg step, always accumulate
|
||||
if is_reg_step:
|
||||
self.is_grad_accumulation_step = True
|
||||
|
||||
# setup accumulation
|
||||
if self.train_config.gradient_accumulation_steps == -1:
|
||||
# epoch is handling the accumulation, dont touch it
|
||||
|
||||
@@ -6,6 +6,8 @@ import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
@@ -50,6 +52,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
@@ -61,10 +64,17 @@ find_matches = False
|
||||
|
||||
print(f'Loading diffusers model')
|
||||
|
||||
ignore_ldm_begins_with = []
|
||||
|
||||
diffusers_file_path = file_path
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
|
||||
if args.refiner:
|
||||
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||
|
||||
if not args.refiner:
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
@@ -84,6 +94,38 @@ flush()
|
||||
|
||||
print(f'Loading ldm model')
|
||||
diffusers_state_dict = diffusers_sd.state_dict()
|
||||
else:
|
||||
# refiner wont work directly with stable diffusion
|
||||
# so we need to load the model and then load the state dict
|
||||
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
file_path,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_REFINER_UNET = "refiner_unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te1"
|
||||
|
||||
diffusers_state_dict = OrderedDict()
|
||||
for k, v in diffusers_pipeline.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
|
||||
# add ignore ones as we are only going to focus on unet and copy the rest
|
||||
# ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
|
||||
|
||||
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||
|
||||
ldm_state_dict = load_file(file_path)
|
||||
@@ -113,6 +155,12 @@ if args.sdxl or args.ssd:
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.1.model.text_projection"
|
||||
if args.refiner:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.0.model.text_projection"
|
||||
if args.sd2:
|
||||
te_suffix = ''
|
||||
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
|
||||
@@ -120,7 +168,7 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2 or args.ssd:
|
||||
if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
@@ -297,6 +345,8 @@ if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.refiner:
|
||||
name += '_refiner'
|
||||
elif args.sd2:
|
||||
name += '_sd2'
|
||||
else:
|
||||
|
||||
@@ -160,6 +160,7 @@ class TrainConfig:
|
||||
self.lr = kwargs.get('lr', 1e-6)
|
||||
self.unet_lr = kwargs.get('unet_lr', self.lr)
|
||||
self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
|
||||
self.refiner_lr = kwargs.get('refiner_lr', self.lr)
|
||||
self.embedding_lr = kwargs.get('embedding_lr', self.lr)
|
||||
self.adapter_lr = kwargs.get('adapter_lr', self.lr)
|
||||
self.optimizer = kwargs.get('optimizer', 'adamw')
|
||||
@@ -174,6 +175,7 @@ class TrainConfig:
|
||||
self.sdp = kwargs.get('sdp', False)
|
||||
self.train_unet = kwargs.get('train_unet', True)
|
||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||
self.train_refiner = kwargs.get('train_refiner', True)
|
||||
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||
self.snr_gamma = kwargs.get('snr_gamma', None)
|
||||
# trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
|
||||
@@ -238,6 +240,8 @@ class ModelConfig:
|
||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||
self.vae_path = kwargs.get('vae_path', None)
|
||||
self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
|
||||
|
||||
# only for SDXL models for now
|
||||
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
|
||||
|
||||
3499
toolkit/keymaps/stable_diffusion_refiner.json
Normal file
3499
toolkit/keymaps/stable_diffusion_refiner.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
Normal file
Binary file not shown.
10
toolkit/keymaps/stable_diffusion_refiner_unmatched.json
Normal file
10
toolkit/keymaps/stable_diffusion_refiner_unmatched.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"ldm": {
|
||||
"conditioner.embedders.0.model.logit_scale": {
|
||||
"shape": [],
|
||||
"min": 4.60546875,
|
||||
"max": 4.60546875
|
||||
}
|
||||
},
|
||||
"diffusers": {}
|
||||
}
|
||||
91
toolkit/orig_configs/sd_xl_refiner.yaml
Normal file
91
toolkit/orig_configs/sd_xl_refiner.yaml
Normal file
@@ -0,0 +1,91 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2560
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 384
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 4
|
||||
context_dim: [1280, 1280, 1280, 1280] # 1280
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn and vector cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
legacy: False
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: aesthetic_score
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by one
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
@@ -5,6 +5,8 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||
ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
|
||||
DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
|
||||
|
||||
# check if ENV variable is set
|
||||
if 'MODELS_PATH' in os.environ:
|
||||
|
||||
@@ -8,10 +8,18 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from diffusers.utils import is_torch_xla_available
|
||||
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
|
||||
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
|
||||
|
||||
def __init__(
|
||||
@@ -807,3 +815,389 @@ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
return noise_pred
|
||||
|
||||
|
||||
class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
denoising_start (`float`, *optional*):
|
||||
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
|
||||
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
|
||||
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
|
||||
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 8.1 Apply denoising_end
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 8.2 Determine denoising_start
|
||||
denoising_start_index = 0
|
||||
if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
|
||||
discrete_timestep_start = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_start * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
|
||||
|
||||
|
||||
with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
|
||||
for i, t in enumerate(timesteps, start=denoising_start_index):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
if not output_type == "latent":
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
elif sd_version == 'sdxl_refiner':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ empty_preset = {
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
'refiner_unet': {
|
||||
'training': False,
|
||||
'requires_grad': False,
|
||||
'device': 'cpu',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +39,7 @@ def get_train_sd_device_state_preset(
|
||||
train_lora: bool = False,
|
||||
train_adapter: bool = False,
|
||||
train_embedding: bool = False,
|
||||
train_refiner: bool = False,
|
||||
):
|
||||
preset = copy.deepcopy(empty_preset)
|
||||
if not cached_latents:
|
||||
@@ -59,9 +65,16 @@ def get_train_sd_device_state_preset(
|
||||
preset['text_encoder']['training'] = True
|
||||
preset['unet']['training'] = True
|
||||
|
||||
if train_refiner:
|
||||
preset['refiner_unet']['training'] = True
|
||||
preset['refiner_unet']['requires_grad'] = True
|
||||
preset['refiner_unet']['device'] = device
|
||||
|
||||
if train_lora:
|
||||
# preset['text_encoder']['requires_grad'] = False
|
||||
preset['unet']['requires_grad'] = False
|
||||
if train_refiner:
|
||||
preset['refiner_unet']['requires_grad'] = False
|
||||
|
||||
if train_adapter:
|
||||
preset['adapter']['requires_grad'] = True
|
||||
|
||||
@@ -26,21 +26,28 @@ from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
|
||||
from toolkit.sd_device_states_presets import empty_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||
StableDiffusionXLImg2ImgPipeline
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
UNet2DConditionModel
|
||||
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_REFINER_UNET = "refiner_unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
@@ -52,6 +59,10 @@ DO_NOT_TRAIN_WEIGHTS = [
|
||||
"unet_time_embedding.linear_1.weight",
|
||||
"unet_time_embedding.linear_2.bias",
|
||||
"unet_time_embedding.linear_2.weight",
|
||||
"refiner_unet_time_embedding.linear_1.bias",
|
||||
"refiner_unet_time_embedding.linear_1.weight",
|
||||
"refiner_unet_time_embedding.linear_2.bias",
|
||||
"refiner_unet_time_embedding.linear_2.weight",
|
||||
]
|
||||
|
||||
DeviceStatePreset = Literal['cache_latents', 'generate']
|
||||
@@ -81,10 +92,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from diffusers import \
|
||||
StableDiffusionPipeline, \
|
||||
AutoencoderKL, \
|
||||
UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
|
||||
@@ -116,6 +123,8 @@ class StableDiffusion:
|
||||
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
||||
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
||||
|
||||
self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
|
||||
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
self.ckppt_info = None
|
||||
@@ -214,7 +223,7 @@ class StableDiffusion:
|
||||
pipln = StableDiffusionPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(model_path):
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
@@ -263,10 +272,47 @@ class StableDiffusion:
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.pipeline = pipe
|
||||
self.load_refiner()
|
||||
self.is_loaded = True
|
||||
|
||||
def load_refiner(self):
|
||||
# for now, we are just going to rely on the TE from the base model
|
||||
# which is TE2 for SDXL and TE for SD (no refiner currently)
|
||||
# and completely ignore a TE that may or may not be packaged with the refiner
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
|
||||
# load the refiner model
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
model_path = self.model_config.refiner_name_or_path
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# TODO only load unet??
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device=self.device_torch,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
original_config_file=refiner_config_path,
|
||||
).to(self.device_torch)
|
||||
|
||||
self.refiner_unet = refiner.unet
|
||||
del refiner
|
||||
flush()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
|
||||
def generate_images(
|
||||
self,
|
||||
image_configs: List[GenerateImageConfig],
|
||||
sampler=None,
|
||||
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
|
||||
):
|
||||
merge_multiplier = 1.0
|
||||
# sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if self.network is not None:
|
||||
@@ -289,6 +335,7 @@ class StableDiffusion:
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
|
||||
if pipeline is None:
|
||||
noise_scheduler = self.noise_scheduler
|
||||
if sampler is not None:
|
||||
if sampler.startswith("sample_"): # sample_dpmpp_2m
|
||||
@@ -349,6 +396,25 @@ class StableDiffusion:
|
||||
if sampler.startswith("sample_"):
|
||||
pipeline.set_scheduler(sampler)
|
||||
|
||||
refiner_pipeline = None
|
||||
if self.refiner_unet:
|
||||
# build refiner pipeline
|
||||
refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=pipeline.vae,
|
||||
unet=self.refiner_unet,
|
||||
text_encoder=None,
|
||||
text_encoder_2=pipeline.text_encoder_2,
|
||||
tokenizer=None,
|
||||
tokenizer_2=pipeline.tokenizer_2,
|
||||
scheduler=pipeline.scheduler,
|
||||
add_watermarker=False,
|
||||
requires_aesthetics_score=True,
|
||||
).to(self.device_torch)
|
||||
# refiner_pipeline.register_to_config(requires_aesthetics_score=False)
|
||||
refiner_pipeline.watermark = None
|
||||
refiner_pipeline.set_progress_bar_config(disable=True)
|
||||
flush()
|
||||
|
||||
start_multiplier = 1.0
|
||||
if self.network is not None:
|
||||
start_multiplier = self.network.multiplier
|
||||
@@ -406,14 +472,20 @@ class StableDiffusion:
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||
|
||||
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
|
||||
if self.refiner_unet is not None:
|
||||
# if we have a refiner loaded, set the denoising end at the refiner start
|
||||
extra['denoising_end'] = self.model_config.refiner_start_at
|
||||
extra['output_type'] = 'latent'
|
||||
if not self.is_xl:
|
||||
raise ValueError("Refiner is only supported for XL models")
|
||||
|
||||
if self.is_xl:
|
||||
# fix guidance rescale for sdxl
|
||||
# was trained on 0.7 (I believe)
|
||||
|
||||
grs = gen_config.guidance_rescale
|
||||
if grs is None or grs < 0.00001:
|
||||
grs = 0.7
|
||||
# if grs is None or grs < 0.00001:
|
||||
# grs = 0.7
|
||||
# grs = 0.0
|
||||
|
||||
if sampler.startswith("sample_"):
|
||||
@@ -454,10 +526,37 @@ class StableDiffusion:
|
||||
**extra
|
||||
).images[0]
|
||||
|
||||
if refiner_pipeline is not None:
|
||||
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
|
||||
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
|
||||
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
|
||||
refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
|
||||
# run through refiner
|
||||
img = refiner_pipeline(
|
||||
# prompt=gen_config.prompt,
|
||||
# prompt_2=gen_config.prompt_2,
|
||||
|
||||
# slice these as it does not use both text encoders
|
||||
# height=gen_config.height,
|
||||
# width=gen_config.width,
|
||||
prompt_embeds=refiner_text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
negative_prompt_embeds=refiner_unconditional_text_embeds,
|
||||
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=grs,
|
||||
denoising_start=self.model_config.refiner_start_at,
|
||||
denoising_end=gen_config.num_inference_steps,
|
||||
image=img.unsqueeze(0)
|
||||
).images[0]
|
||||
|
||||
gen_config.save_image(img, i)
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
if refiner_pipeline is not None:
|
||||
del refiner_pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# restore training state
|
||||
@@ -505,7 +604,7 @@ class StableDiffusion:
|
||||
noise = apply_noise_offset(noise, noise_offset)
|
||||
return noise
|
||||
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
if self.is_xl:
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
@@ -518,6 +617,12 @@ class StableDiffusion:
|
||||
target_size = (height, width)
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
if requires_aesthetic_score:
|
||||
# refiner
|
||||
# https://huggingface.co/papers/2307.01952
|
||||
aesthetic_score = 6.0 # simulate one
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||
@@ -588,6 +693,60 @@ class StableDiffusion:
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# we have the refiner on the second half of everything. Do Both
|
||||
if do_classifier_free_guidance:
|
||||
raise ValueError("Refiner is not supported with classifier free guidance")
|
||||
|
||||
if self.unet.training:
|
||||
input_chunks = torch.chunk(latent_model_input, 2, dim=0)
|
||||
timestep_chunks = torch.chunk(timestep, 2, dim=0)
|
||||
added_cond_kwargs_chunked = {
|
||||
"text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
|
||||
"time_ids": torch.chunk(add_time_ids, 2, dim=0),
|
||||
}
|
||||
text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)
|
||||
|
||||
# predict the noise residual
|
||||
base_pred = self.unet(
|
||||
input_chunks[0],
|
||||
timestep_chunks[0],
|
||||
encoder_hidden_states=text_embeds_chunks[0],
|
||||
added_cond_kwargs={
|
||||
"text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
|
||||
"time_ids": added_cond_kwargs_chunked['time_ids'][0],
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
refiner_pred = self.refiner_unet(
|
||||
input_chunks[1],
|
||||
timestep_chunks[1],
|
||||
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
|
||||
# "time_ids": added_cond_kwargs_chunked['time_ids'][1],
|
||||
"time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
|
||||
else:
|
||||
noise_pred = self.refiner_unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
|
||||
# just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
else:
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
@@ -852,7 +1011,7 @@ class StableDiffusion:
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
|
||||
str, Parameter]:
|
||||
named_params: OrderedDict[str, Parameter] = OrderedDict()
|
||||
if vae:
|
||||
@@ -877,6 +1036,10 @@ class StableDiffusion:
|
||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
if refiner:
|
||||
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
# convert to state dict keys, jsut replace . with _ on keys
|
||||
if state_dict_keys:
|
||||
new_named_params = OrderedDict()
|
||||
@@ -888,6 +1051,64 @@ class StableDiffusion:
|
||||
|
||||
return named_params
|
||||
|
||||
def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
|
||||
|
||||
# load the full refiner since we only train unet
|
||||
if self.model_config.refiner_name_or_path is None:
|
||||
raise ValueError("Refiner must be specified to save it")
|
||||
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
|
||||
# load the refiner model
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
model_path = self.model_config.refiner_name_or_path
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# TODO only load unet??
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device='cpu',
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
else:
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device='cpu',
|
||||
torch_dtype=self.torch_dtype,
|
||||
original_config_file=refiner_config_path,
|
||||
)
|
||||
# replace original unet
|
||||
refiner.unet = self.refiner_unet
|
||||
flush()
|
||||
|
||||
diffusers_state_dict = OrderedDict()
|
||||
for k, v in refiner.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in refiner.text_encoder_2.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in refiner.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
diffusers_state_dict,
|
||||
'sdxl_refiner',
|
||||
device='cpu',
|
||||
dtype=save_dtype
|
||||
)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
version_string = '1'
|
||||
if self.is_v2:
|
||||
@@ -929,6 +1150,8 @@ class StableDiffusion:
|
||||
text_encoder=False,
|
||||
text_encoder_lr=None,
|
||||
unet_lr=None,
|
||||
refiner_lr=None,
|
||||
refiner=False,
|
||||
default_lr=1e-6,
|
||||
):
|
||||
# todo maybe only get locon ones?
|
||||
@@ -974,6 +1197,20 @@ class StableDiffusion:
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
if refiner:
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
|
||||
refiner_lr = refiner_lr if refiner_lr is not None else default_lr
|
||||
params = []
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
diffusers_key = f"refiner_{diffusers_key}"
|
||||
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
if named_params[diffusers_key].requires_grad:
|
||||
params.append(named_params[diffusers_key])
|
||||
param_data = {"params": params, "lr": refiner_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in refiner")
|
||||
|
||||
return trainable_parameters
|
||||
|
||||
def save_device_state(self):
|
||||
@@ -1021,6 +1258,13 @@ class StableDiffusion:
|
||||
'requires_grad': requires_grad,
|
||||
}
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
self.device_state['refiner_unet'] = {
|
||||
'training': self.refiner_unet.training,
|
||||
'device': self.refiner_unet.device,
|
||||
'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
# restores the device state for all modules
|
||||
# this is useful for when we want to alter the state and restore it
|
||||
@@ -1075,6 +1319,14 @@ class StableDiffusion:
|
||||
self.adapter.train()
|
||||
else:
|
||||
self.adapter.eval()
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
self.refiner_unet.to(state['refiner_unet']['device'])
|
||||
self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
|
||||
if state['refiner_unet']['training']:
|
||||
self.refiner_unet.train()
|
||||
else:
|
||||
self.refiner_unet.eval()
|
||||
flush()
|
||||
|
||||
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
|
||||
@@ -1088,7 +1340,7 @@ class StableDiffusion:
|
||||
if device_state_preset in ['cache_latents']:
|
||||
active_modules = ['vae']
|
||||
if device_state_preset in ['generate']:
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
|
||||
|
||||
state = copy.deepcopy(empty_preset)
|
||||
# vae
|
||||
@@ -1105,6 +1357,13 @@ class StableDiffusion:
|
||||
'requires_grad': 'unet' in training_modules,
|
||||
}
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
state['refiner_unet'] = {
|
||||
'training': 'refiner_unet' in training_modules,
|
||||
'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
|
||||
'requires_grad': 'refiner_unet' in training_modules,
|
||||
}
|
||||
|
||||
# text encoder
|
||||
if isinstance(self.text_encoder, list):
|
||||
state['text_encoder'] = []
|
||||
|
||||
@@ -691,10 +691,11 @@ class LearnableSNRGamma:
|
||||
def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
|
||||
self.device = device
|
||||
self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
|
||||
self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
|
||||
self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
|
||||
self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
|
||||
self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
|
||||
self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
|
||||
self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
|
||||
self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
|
||||
self.buffer = []
|
||||
self.max_buffer_size = 20
|
||||
|
||||
@@ -711,7 +712,7 @@ class LearnableSNRGamma:
|
||||
snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
|
||||
base_snrs = snr.clone().detach()
|
||||
snr.requires_grad = True
|
||||
snr = snr * self.scale + self.offset
|
||||
snr = (snr + self.offset_1) * self.scale + self.offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
@@ -726,7 +727,7 @@ class LearnableSNRGamma:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
|
||||
return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()
|
||||
|
||||
|
||||
def apply_learnable_snr_gos(
|
||||
@@ -735,9 +736,9 @@ def apply_learnable_snr_gos(
|
||||
learnable_snr_trainer: LearnableSNRGamma
|
||||
):
|
||||
|
||||
snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||
snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
|
||||
|
||||
snr = snr * scale + offset
|
||||
snr = (snr + offset_1) * scale + offset_2
|
||||
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
|
||||
|
||||
Reference in New Issue
Block a user