Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -382,8 +382,8 @@ class SDTrainer(BaseSDTrainProcess):
        adapter_strength_max = 1.0
    else:
        # training with assistance, we want it low
-       adapter_strength_min = 0.5
-       adapter_strength_max = 0.8
+       adapter_strength_min = 0.4
+       adapter_strength_max = 0.7
        # adapter_strength_min = 0.9
        # adapter_strength_max = 1.1
@@ -431,6 +431,8 @@ class SDTrainer(BaseSDTrainProcess):
        # make the batch splits
        if self.train_config.single_item_batching:
+           if self.model_config.refiner_name_or_path is not None:
+               raise ValueError("Single item batching is not supported when training the refiner")
            batch_size = noisy_latents.shape[0]
            # chunk/split everything
            noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
@@ -452,7 +454,6 @@ class SDTrainer(BaseSDTrainProcess):
            prompt_2_list = [[prompt] for prompt in prompts_2]
        else:
-           # but it all in an array
            noisy_latents_list = [noisy_latents]
            noise_list = [noise]
            timesteps_list = [timesteps]
@@ -603,8 +604,13 @@ class SDTrainer(BaseSDTrainProcess):
                # apply gradients
                self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=True)
-               with self.timer('scheduler_step'):
-                   self.lr_scheduler.step()
+           else:
+               # gradient accumulation. Just a place for a breakpoint
+               pass
+
+           # TODO should we only step the scheduler on grad steps? If so, we need to recalculate the last step
+           with self.timer('scheduler_step'):
+               self.lr_scheduler.step()

            if self.embedding is not None:
                with self.timer('restore_embeddings'):

View File

@@ -152,6 +152,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            train_lora=self.network_config is not None,
            train_adapter=is_training_adapter,
            train_embedding=self.embed_config is not None,
+           train_refiner=self.train_config.train_refiner,
        )
        # fine_tuning here means training the actual SD network (Dreambooth, etc.), not LoRA, embeddings, etc.
@@ -382,16 +383,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
        file_path = file_path.replace('.safetensors', '')
        # convert it back to a normal object
        save_meta = parse_metadata_from_safetensors(save_meta)
-       self.sd.save(
-           file_path,
-           save_meta,
-           get_torch_dtype(self.save_config.dtype)
-       )
+       if self.sd.refiner_unet and self.train_config.train_refiner:
+           # save refiner
+           refiner_name = self.job.name + '_refiner'
+           filename = f'{refiner_name}{step_num}.safetensors'
+           file_path = os.path.join(self.save_root, filename)
+           self.sd.save_refiner(
+               file_path,
+               save_meta,
+               get_torch_dtype(self.save_config.dtype)
+           )
+       if self.train_config.train_unet or self.train_config.train_text_encoder:
+           self.sd.save(
+               file_path,
+               save_meta,
+               get_torch_dtype(self.save_config.dtype)
+           )

        # save learnable params as json if we have them
        if self.snr_gos:
            json_data = {
-               'offset': self.snr_gos.offset.item(),
+               'offset_1': self.snr_gos.offset_1.item(),
+               'offset_2': self.snr_gos.offset_2.item(),
                'scale': self.snr_gos.scale.item(),
                'gamma': self.snr_gos.gamma.item(),
            }
@@ -447,7 +461,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # Filter out non-existent paths and sort by creation time
        if paths:
            paths = [p for p in paths if os.path.exists(p)]
-           latest_path = max(paths, key=os.path.getctime)
+           # remove false positives
+           if '_LoRA' not in name:
+               paths = [p for p in paths if '_LoRA' not in p]
+           if '_refiner' not in name:
+               paths = [p for p in paths if '_refiner' not in p]
+           if '_t2i' not in name:
+               paths = [p for p in paths if '_t2i' not in p]
+           if len(paths) > 0:
+               latest_path = max(paths, key=os.path.getctime)

        return latest_path
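For context, a minimal sketch of the checkpoint-resolution logic this hunk implements: candidate paths whose names contain a suffix the caller did not ask for ('_LoRA', '_refiner', '_t2i') are treated as false positives and dropped before picking the newest file. The function and variable names below are illustrative, not the exact toolkit API.

    import os

    def pick_latest_checkpoint(paths, name):
        # drop checkpoints for other artifact types that merely share the job name prefix
        for suffix in ('_LoRA', '_refiner', '_t2i'):
            if suffix not in name:
                paths = [p for p in paths if suffix not in p]
        return max(paths, key=os.path.getctime) if paths else None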
@@ -540,6 +563,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
            # double batch and add short captions to the end
            prompts = prompts + batch.get_caption_short_list()
            is_reg_list = is_reg_list + is_reg_list
+       if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
+           prompts = prompts + prompts
+           is_reg_list = is_reg_list + is_reg_list

        conditioned_prompts = []
@@ -587,6 +613,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)

        batch_size = len(batch.file_items)

+       min_noise_steps = self.train_config.min_denoising_steps
+       max_noise_steps = self.train_config.max_denoising_steps
+       if self.model_config.refiner_name_or_path is not None:
+           # if we are not training the unet, then we are only doing refiner and do not need to double up
+           if self.train_config.train_unet:
+               max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+               do_double = True
+           else:
+               min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+               do_double = False

        with self.timer('prepare_noise'):
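A rough illustration of the ranges set up above, assuming 1000 training timesteps and the default refiner_start_at of 0.5 (both are assumed values, not fixed by this diff): when the base UNet is also trained, its timesteps come from the low half of the schedule and the batch is later doubled with refiner timesteps from the high half; when only the refiner is trained, all timesteps come from the high half.

    # illustrative only; 1000 steps and refiner_start_at = 0.5 are assumed values
    max_steps, refiner_start_at = 1000, 0.5
    boundary = round(max_steps * refiner_start_at)   # 500
    base_range = (0, boundary)            # sampled for the base UNet when train_unet is True
    refiner_range = (boundary, max_steps)  # sampled for the refiner UNet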
@@ -615,18 +651,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    timesteps,
                    0,
                    self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
-                   self.train_config.min_denoising_steps,
-                   self.train_config.max_denoising_steps
+                   min_noise_steps,
+                   max_noise_steps
                )
                timesteps = timesteps.long().clamp(
-                   self.train_config.min_denoising_steps + 1,
-                   self.train_config.max_denoising_steps - 1
+                   min_noise_steps + 1,
+                   max_noise_steps - 1
                )
            elif self.train_config.content_or_style == 'balanced':
                timesteps = torch.randint(
-                   self.train_config.min_denoising_steps,
-                   self.train_config.max_denoising_steps,
+                   min_noise_steps,
+                   max_noise_steps,
                    (batch_size,),
                    device=self.device_torch
                )
@@ -678,9 +714,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
                return torch.cat([tensor, tensor], dim=0)

            if do_double:
-               noisy_latents = double_up_tensor(noisy_latents)
+               if self.model_config.refiner_name_or_path:
+                   # apply refiner double up
+                   refiner_timesteps = torch.randint(
+                       max_noise_steps,
+                       self.train_config.max_denoising_steps,
+                       (batch_size,),
+                       device=self.device_torch
+                   )
+                   refiner_timesteps = refiner_timesteps.long()
+                   # add our new timesteps on to the end
+                   timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
+                   refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
+                   noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
+               else:
+                   # just double it
+                   noisy_latents = double_up_tensor(noisy_latents)
+                   timesteps = double_up_tensor(timesteps)
                noise = double_up_tensor(noise)
-               timesteps = double_up_tensor(timesteps)
                # prompts are already updated above
                imgs = double_up_tensor(imgs)
                batch.mask_tensor = double_up_tensor(batch.mask_tensor)
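A condensed sketch of the doubling logic above: the first half of the batch keeps the base-model timesteps, the second half gets fresh timesteps drawn from the refiner's range, and the latents are re-noised at those timesteps before concatenation. Tensor names mirror the diff, but this is a standalone sketch, not the exact training code.

    import torch

    def double_for_refiner(latents, noise, noisy_latents, timesteps,
                           noise_scheduler, boundary, max_steps, device):
        # second half of the batch is dedicated to the refiner's timestep range
        batch_size = latents.shape[0]
        refiner_t = torch.randint(boundary, max_steps, (batch_size,), device=device).long()
        refiner_noisy = noise_scheduler.add_noise(latents, noise, refiner_t)
        timesteps = torch.cat([timesteps, refiner_t], dim=0)
        noisy_latents = torch.cat([noisy_latents, refiner_noisy], dim=0)
        noise = torch.cat([noise, noise], dim=0)
        return noisy_latents, noise, timesteps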
@@ -772,6 +826,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # get the noise scheduler
        sampler = get_sampler(self.train_config.noise_scheduler)

+       if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
+           previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
+           if previous_refiner_save is not None:
+               model_config_to_load.refiner_name_or_path = previous_refiner_save
+               self.load_training_state_from_metadata(previous_refiner_save)

        self.sd = StableDiffusion(
            device=self.device,
            model_config=model_config_to_load,
@@ -818,6 +878,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    if hasattr(text_encoder, "gradient_checkpointing_enable"):
                        text_encoder.gradient_checkpointing_enable()

+       if self.sd.refiner_unet is not None:
+           self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
+           self.sd.refiner_unet.requires_grad_(False)
+           self.sd.refiner_unet.eval()
+           if self.train_config.xformers:
+               self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
+           if self.train_config.gradient_checkpointing:
+               self.sd.refiner_unet.enable_gradient_checkpointing()

        if isinstance(text_encoder, list):
            for te in text_encoder:
                te.requires_grad_(False)
@@ -840,7 +909,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if os.path.exists(path_to_load):
            with open(path_to_load, 'r') as f:
                json_data = json.load(f)
-               self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+               if 'offset' in json_data:
+                   # legacy
+                   self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
+               else:
+                   self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
+                   self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
                self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
                self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
@@ -1018,7 +1092,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                text_encoder=self.train_config.train_text_encoder,
                text_encoder_lr=self.train_config.lr,
                unet_lr=self.train_config.lr,
-               default_lr=self.train_config.lr
+               default_lr=self.train_config.lr,
+               refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
+               refiner_lr=self.train_config.refiner_lr,
            )
            # we may be using it for prompt injections
            if self.adapter_config is not None:
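The call above implies parameter grouping along these lines (a sketch, not the toolkit's actual get_all_params signature): the refiner UNet's parameters get their own optimizer group with refiner_lr so it can be tuned independently of the base UNet and text encoder.

    def build_param_groups(sd, train_config):
        # sketch: one optimizer group per trainable module, each with its own LR
        groups = []
        if train_config.train_unet:
            groups.append({'params': sd.unet.parameters(), 'lr': train_config.unet_lr})
        if train_config.train_text_encoder:
            text_encoders = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
            for te in text_encoders:
                groups.append({'params': te.parameters(), 'lr': train_config.text_encoder_lr})
        if train_config.train_refiner and sd.refiner_unet is not None:
            groups.append({'params': sd.refiner_unet.parameters(), 'lr': train_config.refiner_lr})
        return groups  # e.g. torch.optim.AdamW(build_param_groups(sd, cfg))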
@@ -1158,6 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
            else:
                batch = None

+           # if we are doing a reg step, always accumulate
+           if is_reg_step:
+               self.is_grad_accumulation_step = True

            # setup accumulation
            if self.train_config.gradient_accumulation_steps == -1:
                # epoch is handling the accumulation, dont touch it

View File

@@ -6,6 +6,8 @@ import os
# add project root to sys path
import sys
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
@@ -50,6 +52,7 @@ parser.add_argument(
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
@@ -61,29 +64,68 @@ find_matches = False
print(f'Loading diffusers model')

+ignore_ldm_begins_with = []

diffusers_file_path = file_path
if args.ssd:
    diffusers_file_path = "segmind/SSD-1B"
+if args.refiner:
+    diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"

-diffusers_model_config = ModelConfig(
-    name_or_path=diffusers_file_path,
-    is_xl=args.sdxl,
-    is_v2=args.sd2,
-    is_ssd=args.ssd,
-    dtype=dtype,
-)
-diffusers_sd = StableDiffusion(
-    model_config=diffusers_model_config,
-    device=device,
-    dtype=dtype,
-)
-diffusers_sd.load_model()
-# delete things we dont need
-del diffusers_sd.tokenizer
-flush()
-
-print(f'Loading ldm model')
-
-diffusers_state_dict = diffusers_sd.state_dict()
+if not args.refiner:
+    diffusers_model_config = ModelConfig(
+        name_or_path=diffusers_file_path,
+        is_xl=args.sdxl,
+        is_v2=args.sd2,
+        is_ssd=args.ssd,
+        dtype=dtype,
+    )
+    diffusers_sd = StableDiffusion(
+        model_config=diffusers_model_config,
+        device=device,
+        dtype=dtype,
+    )
+    diffusers_sd.load_model()
+    # delete things we dont need
+    del diffusers_sd.tokenizer
+    flush()
+
+    print(f'Loading ldm model')
+
+    diffusers_state_dict = diffusers_sd.state_dict()
+else:
+    # refiner wont work directly with stable diffusion
+    # so we need to load the model and then load the state dict
+    diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
+        file_path,
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+        variant="fp16",
+    ).to(device)
+    SD_PREFIX_VAE = "vae"
+    SD_PREFIX_UNET = "unet"
+    SD_PREFIX_REFINER_UNET = "refiner_unet"
+    SD_PREFIX_TEXT_ENCODER = "te"
+    SD_PREFIX_TEXT_ENCODER1 = "te0"
+    SD_PREFIX_TEXT_ENCODER2 = "te1"
+    diffusers_state_dict = OrderedDict()
+    for k, v in diffusers_pipeline.vae.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+        diffusers_state_dict[new_key] = v
+    for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
+        diffusers_state_dict[new_key] = v
+    for k, v in diffusers_pipeline.unet.state_dict().items():
+        new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+        diffusers_state_dict[new_key] = v
+    # add ignore ones as we are only going to focus on unet and copy the rest
+    # ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]

diffusers_dict_keys = list(diffusers_state_dict.keys())

ldm_state_dict = load_file(file_path)
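The loops above build one combined state dict using the toolkit's module prefixes. A compact sketch of that convention (prefix strings match the diff, the helper itself is illustrative):

    from collections import OrderedDict

    def prefixed_state_dict(pipeline):
        # keys are namespaced by module so a single dict can hold vae, te and unet weights
        combined = OrderedDict()
        for prefix, module in [('vae', pipeline.vae),
                               ('te1', pipeline.text_encoder_2),
                               ('unet', pipeline.unet)]:
            for k, v in module.state_dict().items():
                combined[k if k.startswith(f'{prefix}_') else f'{prefix}_{k}'] = v
        return combined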
@@ -113,6 +155,12 @@ if args.sdxl or args.ssd:
    proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
    proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
    text_proj_layer = "conditioner.embedders.1.model.text_projection"

+if args.refiner:
+    te_suffix = '1'
+    ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
+    proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
+    proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
+    text_proj_layer = "conditioner.embedders.0.model.text_projection"

if args.sd2:
    te_suffix = ''
    ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
@@ -120,7 +168,7 @@ if args.sd2:
    proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
    text_proj_layer = "cond_stage_model.model.text_projection"

-if args.sdxl or args.sd2 or args.ssd:
+if args.sdxl or args.sd2 or args.ssd or args.refiner:
    if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
        # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
        d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
@@ -297,6 +345,8 @@ if args.sdxl:
    name += '_sdxl'
elif args.ssd:
    name += '_ssd'
+elif args.refiner:
+    name += '_refiner'
elif args.sd2:
    name += '_sd2'
else:

View File

@@ -160,6 +160,7 @@ class TrainConfig:
        self.lr = kwargs.get('lr', 1e-6)
        self.unet_lr = kwargs.get('unet_lr', self.lr)
        self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
+       self.refiner_lr = kwargs.get('refiner_lr', self.lr)
        self.embedding_lr = kwargs.get('embedding_lr', self.lr)
        self.adapter_lr = kwargs.get('adapter_lr', self.lr)
        self.optimizer = kwargs.get('optimizer', 'adamw')
@@ -174,6 +175,7 @@ class TrainConfig:
        self.sdp = kwargs.get('sdp', False)
        self.train_unet = kwargs.get('train_unet', True)
        self.train_text_encoder = kwargs.get('train_text_encoder', True)
+       self.train_refiner = kwargs.get('train_refiner', True)
        self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
        self.snr_gamma = kwargs.get('snr_gamma', None)
        # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
@@ -238,6 +240,8 @@ class ModelConfig:
        self.is_v_pred: bool = kwargs.get('is_v_pred', False)
        self.dtype: str = kwargs.get('dtype', 'float16')
        self.vae_path = kwargs.get('vae_path', None)
+       self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
+       self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
        # only for SDXL models for now
        self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
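Taken together, the new keys above could be wired up roughly like this in a job config (values are placeholders for illustration, not recommendations; keys map to the kwargs.get(...) calls above):

    model_kwargs = {
        'name_or_path': '/path/to/sdxl_base.safetensors',
        'is_xl': True,
        'refiner_name_or_path': 'stabilityai/stable-diffusion-xl-refiner-1.0',
        'refiner_start_at': 0.5,   # refiner handles the last 50% of the noise schedule
    }
    train_kwargs = {
        'train_unet': True,
        'train_refiner': True,
        'lr': 1e-6,
        'refiner_lr': 5e-7,        # defaults to lr when omitted
    }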

File diff suppressed because it is too large

View File

@@ -0,0 +1,10 @@
{
"ldm": {
"conditioner.embedders.0.model.logit_scale": {
"shape": [],
"min": 4.60546875,
"max": 4.60546875
}
},
"diffusers": {}
}

View File

@@ -0,0 +1,91 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@@ -5,6 +5,8 @@ CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
+ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
+DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")

# check if ENV variable is set
if 'MODELS_PATH' in os.environ:

View File

@@ -8,10 +8,18 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from diffusers.utils import is_torch_xla_available
from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler

+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False


class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
    def __init__(
@@ -807,3 +815,389 @@ class CustomStableDiffusionPipeline(StableDiffusionPipeline):
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
        return noise_pred
class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
denoising_start: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
To negatively condition the generation process based on a target image resolution. It should be as same
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# 8.2 Determine denoising_start
denoising_start_index = 0
if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
discrete_timestep_start = int(
round(
self.scheduler.config.num_train_timesteps
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
for i, t in enumerate(timesteps, start=denoising_start_index):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
else:
image = latents
if not output_type == "latent":
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)

View File

@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
def get_ldm_state_dict_from_diffusers(
        state_dict: 'OrderedDict',
-       sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
+       sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
        device='cpu',
        dtype=get_torch_dtype('fp32'),
):
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
        # load our base
        base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
        mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
+   elif sd_version == 'sdxl_refiner':
+       # load our base
+       base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
+       mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
    else:
        raise ValueError(f"Invalid sd_version {sd_version}")

View File

@@ -23,6 +23,11 @@ empty_preset = {
        'requires_grad': False,
        'device': 'cpu',
    },
+   'refiner_unet': {
+       'training': False,
+       'requires_grad': False,
+       'device': 'cpu',
+   },
}
@@ -34,6 +39,7 @@ def get_train_sd_device_state_preset(
        train_lora: bool = False,
        train_adapter: bool = False,
        train_embedding: bool = False,
+       train_refiner: bool = False,
):
    preset = copy.deepcopy(empty_preset)
    if not cached_latents:
@@ -59,9 +65,16 @@ def get_train_sd_device_state_preset(
        preset['text_encoder']['training'] = True
        preset['unet']['training'] = True
+       if train_refiner:
+           preset['refiner_unet']['training'] = True
+           preset['refiner_unet']['requires_grad'] = True
+           preset['refiner_unet']['device'] = device
        if train_lora:
            # preset['text_encoder']['requires_grad'] = False
            preset['unet']['requires_grad'] = False
+           if train_refiner:
+               preset['refiner_unet']['requires_grad'] = False
        if train_adapter:
            preset['adapter']['requires_grad'] = True

View File

@@ -26,21 +26,28 @@ from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.sampler import get_sampler
-from toolkit.saving import save_ldm_model_from_diffusers
+from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
-    StableDiffusionKDiffusionXLPipeline
+    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
-    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
+    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
+    StableDiffusionXLImg2ImgPipeline
import diffusers
+from diffusers import \
+    AutoencoderKL, \
+    UNet2DConditionModel
+from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT

# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
+SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"

SD_PREFIX_TEXT_ENCODER1 = "te0"
@@ -52,6 +59,10 @@ DO_NOT_TRAIN_WEIGHTS = [
"unet_time_embedding.linear_1.weight", "unet_time_embedding.linear_1.weight",
"unet_time_embedding.linear_2.bias", "unet_time_embedding.linear_2.bias",
"unet_time_embedding.linear_2.weight", "unet_time_embedding.linear_2.weight",
"refiner_unet_time_embedding.linear_1.bias",
"refiner_unet_time_embedding.linear_1.weight",
"refiner_unet_time_embedding.linear_2.bias",
"refiner_unet_time_embedding.linear_2.weight",
] ]
DeviceStatePreset = Literal['cache_latents', 'generate'] DeviceStatePreset = Literal['cache_latents', 'generate']
@@ -81,10 +92,6 @@ UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; XL as well
# if is type checking
if typing.TYPE_CHECKING:
-    from diffusers import \
-        StableDiffusionPipeline, \
-        AutoencoderKL, \
-        UNet2DConditionModel
    from diffusers.schedulers import KarrasDiffusionSchedulers
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
@@ -116,6 +123,8 @@ class StableDiffusion:
        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
+       self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None

        # sdxl stuff
        self.logit_scale = None
        self.ckppt_info = None
@@ -214,7 +223,7 @@ class StableDiffusion:
            pipln = StableDiffusionPipeline

        # see if path exists
-       if not os.path.exists(model_path):
+       if not os.path.exists(model_path) or os.path.isdir(model_path):
            # try to load with default diffusers
            pipe = pipln.from_pretrained(
                model_path,
@@ -263,10 +272,47 @@ class StableDiffusion:
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.pipeline = pipe
+       self.load_refiner()
        self.is_loaded = True
def load_refiner(self):
# for now, we are just going to rely on the TE from the base model
# which is TE2 for SDXL and TE for SD (no refiner currently)
# and completely ignore a TE that may or may not be packaged with the refiner
if self.model_config.refiner_name_or_path is not None:
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
# load the refiner model
dtype = get_torch_dtype(self.dtype)
model_path = self.model_config.refiner_name_or_path
if not os.path.exists(model_path) or os.path.isdir(model_path):
# TODO only load unet??
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
model_path,
dtype=dtype,
device=self.device_torch,
variant="fp16",
use_safetensors=True,
).to(self.device_torch)
else:
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
model_path,
dtype=dtype,
device=self.device_torch,
torch_dtype=self.torch_dtype,
original_config_file=refiner_config_path,
).to(self.device_torch)
self.refiner_unet = refiner.unet
del refiner
flush()
    @torch.no_grad()
-   def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
+   def generate_images(
+           self,
+           image_configs: List[GenerateImageConfig],
+           sampler=None,
+           pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
+   ):
        merge_multiplier = 1.0
        # sample_folder = os.path.join(self.save_root, 'samples')
        if self.network is not None:
@@ -289,65 +335,85 @@ class StableDiffusion:
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

+       if pipeline is None:
            noise_scheduler = self.noise_scheduler
            if sampler is not None:
                if sampler.startswith("sample_"):  # sample_dpmpp_2m
                    # using ksampler
                    noise_scheduler = get_sampler('lms')
                else:
                    noise_scheduler = get_sampler(sampler)

            if sampler.startswith("sample_") and self.is_xl:
                # using kdiffusion
                Pipe = StableDiffusionKDiffusionXLPipeline
            elif self.is_xl:
                Pipe = StableDiffusionXLPipeline
            else:
                Pipe = StableDiffusionPipeline

            extra_args = {}
            if self.adapter is not None:
                if isinstance(self.adapter, T2IAdapter):
                    if self.is_xl:
                        Pipe = StableDiffusionXLAdapterPipeline
                    else:
                        Pipe = StableDiffusionAdapterPipeline
                    extra_args['adapter'] = self.adapter
            else:
                if self.is_xl:
                    extra_args['add_watermarker'] = False

            # TODO add clip skip
            if self.is_xl:
                pipeline = Pipe(
                    vae=self.vae,
                    unet=self.unet,
                    text_encoder=self.text_encoder[0],
                    text_encoder_2=self.text_encoder[1],
                    tokenizer=self.tokenizer[0],
                    tokenizer_2=self.tokenizer[1],
                    scheduler=noise_scheduler,
                    **extra_args
                ).to(self.device_torch)
                pipeline.watermark = None
            else:
                pipeline = Pipe(
                    vae=self.vae,
                    unet=self.unet,
                    text_encoder=self.text_encoder,
                    tokenizer=self.tokenizer,
                    scheduler=noise_scheduler,
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                    **extra_args
                ).to(self.device_torch)
            flush()

            # disable progress bar
            pipeline.set_progress_bar_config(disable=True)

            if sampler.startswith("sample_"):
                pipeline.set_scheduler(sampler)

+       refiner_pipeline = None
+       if self.refiner_unet:
+           # build refiner pipeline
+           refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
+               vae=pipeline.vae,
+               unet=self.refiner_unet,
+               text_encoder=None,
+               text_encoder_2=pipeline.text_encoder_2,
+               tokenizer=None,
+               tokenizer_2=pipeline.tokenizer_2,
+               scheduler=pipeline.scheduler,
+               add_watermarker=False,
+               requires_aesthetics_score=True,
+           ).to(self.device_torch)
+           # refiner_pipeline.register_to_config(requires_aesthetics_score=False)
+           refiner_pipeline.watermark = None
+           refiner_pipeline.set_progress_bar_config(disable=True)
+           flush()

        start_multiplier = 1.0
        if self.network is not None:
@@ -406,14 +472,20 @@ class StableDiffusion:
                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                    unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)

-               # todo do we disable text encoder here as well if disabled for model, or only do that for training?
+               if self.refiner_unet is not None:
+                   # if we have a refiner loaded, set the denoising end at the refiner start
+                   extra['denoising_end'] = self.model_config.refiner_start_at
+                   extra['output_type'] = 'latent'
+                   if not self.is_xl:
+                       raise ValueError("Refiner is only supported for XL models")

                if self.is_xl:
                    # fix guidance rescale for sdxl
                    # was trained on 0.7 (I believe)
                    grs = gen_config.guidance_rescale
-                   if grs is None or grs < 0.00001:
-                       grs = 0.7
+                   # if grs is None or grs < 0.00001:
+                   #     grs = 0.7
                    # grs = 0.0

                if sampler.startswith("sample_"):
@@ -454,10 +526,37 @@ class StableDiffusion:
                        **extra
                    ).images[0]

+                   if refiner_pipeline is not None:
+                       # slice off just the last 1280 on the last dim, as the refiner does not use the first text encoder
+                       # todo, should we just use the text encoder for the refiner? Fine tuned versions will differ
+                       refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
+                       refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
+                       # run through the refiner
+                       img = refiner_pipeline(
+                           # prompt=gen_config.prompt,
+                           # prompt_2=gen_config.prompt_2,
+                           # slice these as it does not use both text encoders
+                           # height=gen_config.height,
+                           # width=gen_config.width,
+                           prompt_embeds=refiner_text_embeds,
+                           pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                           negative_prompt_embeds=refiner_unconditional_text_embeds,
+                           negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                           num_inference_steps=gen_config.num_inference_steps,
+                           guidance_scale=gen_config.guidance_scale,
+                           guidance_rescale=grs,
+                           denoising_start=self.model_config.refiner_start_at,
+                           denoising_end=gen_config.num_inference_steps,
+                           image=img.unsqueeze(0)
+                       ).images[0]

                gen_config.save_image(img, i)

        # clear pipeline and cache to reduce vram usage
        del pipeline
+       if refiner_pipeline is not None:
+           del refiner_pipeline
        torch.cuda.empty_cache()

        # restore training state
@@ -505,7 +604,7 @@ class StableDiffusion:
            noise = apply_noise_offset(noise, noise_offset)
        return noise

    def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        if self.is_xl:
            bs, ch, h, w = list(latents.shape)
@@ -518,7 +617,13 @@ class StableDiffusion:
            target_size = (height, width)
            original_size = (height, width)
            crops_coords_top_left = (0, 0)
            if requires_aesthetic_score:
                # refiner
                # https://huggingface.co/papers/2307.01952
                aesthetic_score = 6.0  # simulate one
                add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            else:
                add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_time_ids = torch.tensor([add_time_ids])
            add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
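# --- illustrative sketch, not part of this commit -----------------------------
# Shape of the SDXL micro-conditioning built above, assuming a 1024x1024 output
# and no cropping: the base UNet gets 6 values, while the refiner swaps the
# target size for a single aesthetic score (https://huggingface.co/papers/2307.01952).
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
aesthetic_score = 6.0

base_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
refiner_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (aesthetic_score,))])
print(base_time_ids.shape)     # torch.Size([1, 6])
print(refiner_time_ids.shape)  # torch.Size([1, 5])
# --- end sketch ----------------------------------------------------------------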
@@ -588,14 +693,68 @@ class StableDiffusion:
                "time_ids": add_time_ids,
            }

            if self.model_config.refiner_name_or_path is not None:
                # we have the refiner on the second half of everything. Do Both
                if do_classifier_free_guidance:
                    raise ValueError("Refiner is not supported with classifier free guidance")

                if self.unet.training:
                    input_chunks = torch.chunk(latent_model_input, 2, dim=0)
                    timestep_chunks = torch.chunk(timestep, 2, dim=0)
                    added_cond_kwargs_chunked = {
                        "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
                        "time_ids": torch.chunk(add_time_ids, 2, dim=0),
                    }
                    text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)

                    # predict the noise residual
                    base_pred = self.unet(
                        input_chunks[0],
                        timestep_chunks[0],
                        encoder_hidden_states=text_embeds_chunks[0],
                        added_cond_kwargs={
                            "text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
                            "time_ids": added_cond_kwargs_chunked['time_ids'][0],
                        },
                        **kwargs,
                    ).sample
                    refiner_pred = self.refiner_unet(
                        input_chunks[1],
                        timestep_chunks[1],
                        # the refiner only uses the second text encoder, so keep the last 1280 channels
                        encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
                        added_cond_kwargs={
                            "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                            # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
                            "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
                        },
                        **kwargs,
                    ).sample
                    noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
                else:
                    noise_pred = self.refiner_unet(
                        latent_model_input,
                        timestep,
                        # the refiner only uses the second text encoder, so keep the last 1280 channels
                        encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
                        added_cond_kwargs={
                            "text_embeds": text_embeddings.pooled_embeds,
                            "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
                        },
                        **kwargs,
                    ).sample
            else:
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    timestep,
                    encoder_hidden_states=text_embeddings.text_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    **kwargs,
                ).sample

            if do_classifier_free_guidance:
                # perform guidance
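# --- illustrative sketch, not part of this commit -----------------------------
# The training branch above routes the first half of the batch through the base
# UNet and the second half through the refiner UNet, then concatenates the two
# predictions so the downstream loss code is unchanged. The chunk/cat pattern:
import torch

latent_model_input = torch.randn(4, 4, 128, 128)    # stand-in batch of 4 latents
base_half, refiner_half = torch.chunk(latent_model_input, 2, dim=0)
noise_pred = torch.cat([base_half, refiner_half], dim=0)
assert torch.equal(noise_pred, latent_model_input)  # order is preserved
# The refiner only consumes the second (OpenCLIP ViT-bigG) text encoder's hidden
# states, which is why the embeddings are sliced to their last 1280 channels.
# --- end sketch ----------------------------------------------------------------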
@@ -852,7 +1011,7 @@ class StableDiffusion:
                state_dict[new_key] = v
        return state_dict

    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
        str, Parameter]:
        named_params: OrderedDict[str, Parameter] = OrderedDict()
        if vae:
@@ -877,6 +1036,10 @@ class StableDiffusion:
            for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                named_params[name] = param

        if refiner:
            for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
                named_params[name] = param

        # convert to state dict keys, just replace . with _ on keys
        if state_dict_keys:
            new_named_params = OrderedDict()
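# --- illustrative sketch, not part of this commit -----------------------------
# The refiner branch above relies on torch's built-in `prefix` argument to
# named_parameters, which namespaces every parameter name; the exact prefix
# string (SD_PREFIX_REFINER_UNET) is a constant from this repo.
import torch

module = torch.nn.Linear(4, 4)
print([name for name, _ in module.named_parameters(recurse=True, prefix="refiner_unet")])
# ['refiner_unet.weight', 'refiner_unet.bias']
# --- end sketch ----------------------------------------------------------------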
@@ -888,6 +1051,64 @@ class StableDiffusion:
        return named_params

    def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
        # load the full refiner since we only train the unet
        if self.model_config.refiner_name_or_path is None:
            raise ValueError("Refiner must be specified to save it")
        refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
        # load the refiner model
        dtype = get_torch_dtype(self.dtype)
        model_path = self.model_config.refiner_name_or_path
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            # TODO only load unet??
            refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_path,
                dtype=dtype,
                device='cpu',
                variant="fp16",
                use_safetensors=True,
            )
        else:
            refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
                model_path,
                dtype=dtype,
                device='cpu',
                torch_dtype=self.torch_dtype,
                original_config_file=refiner_config_path,
            )
        # replace the original unet
        refiner.unet = self.refiner_unet
        flush()

        diffusers_state_dict = OrderedDict()
        for k, v in refiner.vae.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
            diffusers_state_dict[new_key] = v
        for k, v in refiner.text_encoder_2.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
            diffusers_state_dict[new_key] = v
        for k, v in refiner.unet.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
            diffusers_state_dict[new_key] = v

        converted_state_dict = get_ldm_state_dict_from_diffusers(
            diffusers_state_dict,
            'sdxl_refiner',
            device='cpu',
            dtype=save_dtype
        )

        # make sure the parent folder exists
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        save_file(converted_state_dict, output_file, metadata=meta)

        if self.config_file is not None:
            output_path_no_ext = os.path.splitext(output_file)[0]
            output_config_path = f"{output_path_no_ext}.yaml"
            shutil.copyfile(self.config_file, output_config_path)
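    # --- illustrative sketch, not part of this commit --------------------------
    # save_refiner() only swaps the trained UNet into a freshly loaded refiner
    # pipeline and re-serializes everything in the LDM key layout. The
    # per-module prefixing step above is the usual pattern of namespacing a
    # state dict before merging, e.g.:
    #
    #   def prefix_keys(sd, prefix):
    #       return {k if k.startswith(f"{prefix}_") else f"{prefix}_{k}": v
    #               for k, v in sd.items()}
    #
    #   merged = {}
    #   merged.update(prefix_keys(refiner.vae.state_dict(), "vae"))
    #   merged.update(prefix_keys(refiner.unet.state_dict(), "unet"))
    #
    # (the prefix strings here are placeholders; the real ones are the
    # SD_PREFIX_* constants used above)
    # ----------------------------------------------------------------------------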
    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
        version_string = '1'
        if self.is_v2:
@@ -929,6 +1150,8 @@ class StableDiffusion:
            text_encoder=False,
            text_encoder_lr=None,
            unet_lr=None,
            refiner_lr=None,
            refiner=False,
            default_lr=1e-6,
    ):
        # todo maybe only get locon ones?
@@ -974,6 +1197,20 @@ class StableDiffusion:
            print(f"Found {len(params)} trainable parameter in text encoder")

        if refiner:
            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
            refiner_lr = refiner_lr if refiner_lr is not None else default_lr
            params = []
            for key, diffusers_key in ldm_diffusers_keymap.items():
                diffusers_key = f"refiner_{diffusers_key}"
                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
                    if named_params[diffusers_key].requires_grad:
                        params.append(named_params[diffusers_key])
            param_data = {"params": params, "lr": refiner_lr}
            trainable_parameters.append(param_data)
            print(f"Found {len(params)} trainable parameter in refiner")
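        # --- illustrative sketch, not part of this commit ----------------------
        # Each entry appended to trainable_parameters is a standard torch param
        # group, so the separate refiner learning rate simply becomes another
        # group when the optimizer is built, e.g.:
        #
        #   optimizer = torch.optim.AdamW([
        #       {"params": unet_params, "lr": unet_lr or default_lr},
        #       {"params": refiner_params, "lr": refiner_lr or default_lr},
        #   ])
        # ------------------------------------------------------------------------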
        return trainable_parameters

    def save_device_state(self):
@@ -1021,6 +1258,13 @@ class StableDiffusion:
                'requires_grad': requires_grad,
            }

        if self.refiner_unet is not None:
            self.device_state['refiner_unet'] = {
                'training': self.refiner_unet.training,
                'device': self.refiner_unet.device,
                'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
            }
    def restore_device_state(self):
        # restores the device state for all modules
        # this is useful for when we want to alter the state and restore it
@@ -1075,6 +1319,14 @@ class StableDiffusion:
                self.adapter.train()
            else:
                self.adapter.eval()

        if self.refiner_unet is not None:
            self.refiner_unet.to(state['refiner_unet']['device'])
            self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
            if state['refiner_unet']['training']:
                self.refiner_unet.train()
            else:
                self.refiner_unet.eval()
        flush()
    def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -1088,7 +1340,7 @@ class StableDiffusion:
        if device_state_preset in ['cache_latents']:
            active_modules = ['vae']
        if device_state_preset in ['generate']:
            active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']

        state = copy.deepcopy(empty_preset)
        # vae
@@ -1105,6 +1357,13 @@ class StableDiffusion:
            'requires_grad': 'unet' in training_modules,
        }

        if self.refiner_unet is not None:
            state['refiner_unet'] = {
                'training': 'refiner_unet' in training_modules,
                'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
                'requires_grad': 'refiner_unet' in training_modules,
            }

        # text encoder
        if isinstance(self.text_encoder, list):
            state['text_encoder'] = []
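# --- illustrative sketch, not part of this commit -----------------------------
# The preset/restore bookkeeping above is a snapshot-switch-restore pattern.
# Reduced to a single module it looks like this:
import torch

def snapshot(module: torch.nn.Module) -> dict:
    param = next(module.parameters())
    return {"training": module.training, "device": param.device, "requires_grad": param.requires_grad}

def restore(module: torch.nn.Module, state: dict) -> None:
    module.to(state["device"])
    module.requires_grad_(state["requires_grad"])
    module.train(state["training"])

layer = torch.nn.Linear(4, 4)
saved = snapshot(layer)
layer.eval().requires_grad_(False)  # e.g. while generating samples
restore(layer, saved)               # back to the training configuration
# --- end sketch ----------------------------------------------------------------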
View File
@@ -691,10 +691,11 @@ class LearnableSNRGamma:
    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
        self.device = device
        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
        self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
        self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
        self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
        self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
        self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
        self.buffer = []
        self.max_buffer_size = 20
@@ -711,7 +712,7 @@ class LearnableSNRGamma:
        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
        base_snrs = snr.clone().detach()
        snr.requires_grad = True
        snr = (snr + self.offset_1) * self.scale + self.offset_2

        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
@@ -726,18 +727,18 @@ class LearnableSNRGamma:
        self.optimizer.step()
        self.optimizer.zero_grad()
        return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()


def apply_learnable_snr_gos(
        loss,
        timesteps,
        learnable_snr_trainer: LearnableSNRGamma
):
    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
    snr = (snr + offset_1) * scale + offset_2
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
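# --- illustrative sketch, not part of this commit -----------------------------
# The change above replaces the single learned offset with a pre- and post-offset,
# i.e. the SNR now goes through a full affine map before the gamma weighting:
#
#   old:  snr' = snr * scale + offset
#   new:  snr' = (snr + offset_1) * scale + offset_2
#
# With the initial values (offset_1=0.0, offset_2=0.777, scale=4.14) both forms
# start out identical; offset_1 just adds a learnable shift inside the product.
import torch

snr = torch.tensor([0.5, 4.0, 20.0])
offset_1, offset_2, scale, gamma = 0.0, 0.777, 4.14, 2.03
snr_adj = (snr + offset_1) * scale + offset_2
snr_weight = torch.abs(gamma / snr_adj)
print(snr_weight)  # low-noise (high-SNR) timesteps get a smaller loss weight
# --- end sketch ----------------------------------------------------------------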