Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -26,21 +26,28 @@ from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
UNet2DConditionModel
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te0"
@@ -52,6 +59,10 @@ DO_NOT_TRAIN_WEIGHTS = [
"unet_time_embedding.linear_1.weight",
"unet_time_embedding.linear_2.bias",
"unet_time_embedding.linear_2.weight",
"refiner_unet_time_embedding.linear_1.bias",
"refiner_unet_time_embedding.linear_1.weight",
"refiner_unet_time_embedding.linear_2.bias",
"refiner_unet_time_embedding.linear_2.weight",
]
DeviceStatePreset = Literal['cache_latents', 'generate']
@@ -81,10 +92,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
# if is type checking
if typing.TYPE_CHECKING:
from diffusers import \
StableDiffusionPipeline, \
AutoencoderKL, \
UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
@@ -116,6 +123,8 @@ class StableDiffusion:
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
# sdxl stuff
self.logit_scale = None
self.ckppt_info = None
@@ -214,7 +223,7 @@ class StableDiffusion:
pipln = StableDiffusionPipeline
# see if path exists
if not os.path.exists(model_path):
if not os.path.exists(model_path) or os.path.isdir(model_path):
# try to load with default diffusers
pipe = pipln.from_pretrained(
model_path,
@@ -263,10 +272,47 @@ class StableDiffusion:
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.pipeline = pipe
self.load_refiner()
self.is_loaded = True
def load_refiner(self):
# for now, we are just going to rely on the TE from the base model
# which is TE2 for SDXL and TE for SD (no refiner currently)
# and completely ignore a TE that may or may not be packaged with the refiner
if self.model_config.refiner_name_or_path is not None:
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
# load the refiner model
dtype = get_torch_dtype(self.dtype)
model_path = self.model_config.refiner_name_or_path
if not os.path.exists(model_path) or os.path.isdir(model_path):
# TODO only load unet??
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
model_path,
dtype=dtype,
device=self.device_torch,
variant="fp16",
use_safetensors=True,
).to(self.device_torch)
else:
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
model_path,
dtype=dtype,
device=self.device_torch,
torch_dtype=self.torch_dtype,
original_config_file=refiner_config_path,
).to(self.device_torch)
self.refiner_unet = refiner.unet
del refiner
flush()
@torch.no_grad()
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
def generate_images(
self,
image_configs: List[GenerateImageConfig],
sampler=None,
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
):
merge_multiplier = 1.0
# sample_folder = os.path.join(self.save_root, 'samples')
if self.network is not None:
@@ -289,65 +335,85 @@ class StableDiffusion:
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
noise_scheduler = self.noise_scheduler
if sampler is not None:
if sampler.startswith("sample_"): # sample_dpmpp_2m
# using ksampler
noise_scheduler = get_sampler('lms')
else:
noise_scheduler = get_sampler(sampler)
if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline
elif self.is_xl:
Pipe = StableDiffusionXLPipeline
else:
Pipe = StableDiffusionPipeline
extra_args = {}
if self.adapter is not None:
if isinstance(self.adapter, T2IAdapter):
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
if pipeline is None:
noise_scheduler = self.noise_scheduler
if sampler is not None:
if sampler.startswith("sample_"): # sample_dpmpp_2m
# using ksampler
noise_scheduler = get_sampler('lms')
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
noise_scheduler = get_sampler(sampler)
if sampler.startswith("sample_") and self.is_xl:
# using kdiffusion
Pipe = StableDiffusionKDiffusionXLPipeline
elif self.is_xl:
Pipe = StableDiffusionXLPipeline
else:
if self.is_xl:
extra_args['add_watermarker'] = False
Pipe = StableDiffusionPipeline
# TODO add clip skip
if self.is_xl:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
pipeline.watermark = None
else:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
**extra_args
).to(self.device_torch)
flush()
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
extra_args = {}
if self.adapter is not None:
if isinstance(self.adapter, T2IAdapter):
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
else:
if self.is_xl:
extra_args['add_watermarker'] = False
if sampler.startswith("sample_"):
pipeline.set_scheduler(sampler)
# TODO add clip skip
if self.is_xl:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder[0],
text_encoder_2=self.text_encoder[1],
tokenizer=self.tokenizer[0],
tokenizer_2=self.tokenizer[1],
scheduler=noise_scheduler,
**extra_args
).to(self.device_torch)
pipeline.watermark = None
else:
pipeline = Pipe(
vae=self.vae,
unet=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=noise_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
**extra_args
).to(self.device_torch)
flush()
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
if sampler.startswith("sample_"):
pipeline.set_scheduler(sampler)
refiner_pipeline = None
if self.refiner_unet:
# build refiner pipeline
refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
vae=pipeline.vae,
unet=self.refiner_unet,
text_encoder=None,
text_encoder_2=pipeline.text_encoder_2,
tokenizer=None,
tokenizer_2=pipeline.tokenizer_2,
scheduler=pipeline.scheduler,
add_watermarker=False,
requires_aesthetics_score=True,
).to(self.device_torch)
# refiner_pipeline.register_to_config(requires_aesthetics_score=False)
refiner_pipeline.watermark = None
refiner_pipeline.set_progress_bar_config(disable=True)
flush()
start_multiplier = 1.0
if self.network is not None:
@@ -406,14 +472,20 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.refiner_unet is not None:
# if we have a refiner loaded, set the denoising end at the refiner start
extra['denoising_end'] = self.model_config.refiner_start_at
extra['output_type'] = 'latent'
if not self.is_xl:
raise ValueError("Refiner is only supported for XL models")
if self.is_xl:
# fix guidance rescale for sdxl
# was trained on 0.7 (I believe)
grs = gen_config.guidance_rescale
if grs is None or grs < 0.00001:
grs = 0.7
# if grs is None or grs < 0.00001:
# grs = 0.7
# grs = 0.0
if sampler.startswith("sample_"):
@@ -454,10 +526,37 @@ class StableDiffusion:
**extra
).images[0]
if refiner_pipeline is not None:
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
# run through refiner
img = refiner_pipeline(
# prompt=gen_config.prompt,
# prompt_2=gen_config.prompt_2,
# slice these as it does not use both text encoders
# height=gen_config.height,
# width=gen_config.width,
prompt_embeds=refiner_text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
negative_prompt_embeds=refiner_unconditional_text_embeds,
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
guidance_rescale=grs,
denoising_start=self.model_config.refiner_start_at,
denoising_end=gen_config.num_inference_steps,
image=img.unsqueeze(0)
).images[0]
gen_config.save_image(img, i)
# clear pipeline and cache to reduce vram usage
del pipeline
if refiner_pipeline is not None:
del refiner_pipeline
torch.cuda.empty_cache()
# restore training state
@@ -505,7 +604,7 @@ class StableDiffusion:
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor):
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if self.is_xl:
bs, ch, h, w = list(latents.shape)
@@ -518,7 +617,13 @@ class StableDiffusion:
target_size = (height, width)
original_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
if requires_aesthetic_score:
# refiner
# https://huggingface.co/papers/2307.01952
aesthetic_score = 6.0 # simulate one
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
@@ -588,14 +693,68 @@ class StableDiffusion:
"time_ids": add_time_ids,
}
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
).sample
if self.model_config.refiner_name_or_path is not None:
# we have the refiner on the second half of everything. Do Both
if do_classifier_free_guidance:
raise ValueError("Refiner is not supported with classifier free guidance")
if self.unet.training:
input_chunks = torch.chunk(latent_model_input, 2, dim=0)
timestep_chunks = torch.chunk(timestep, 2, dim=0)
added_cond_kwargs_chunked = {
"text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
"time_ids": torch.chunk(add_time_ids, 2, dim=0),
}
text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)
# predict the noise residual
base_pred = self.unet(
input_chunks[0],
timestep_chunks[0],
encoder_hidden_states=text_embeds_chunks[0],
added_cond_kwargs={
"text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
"time_ids": added_cond_kwargs_chunked['time_ids'][0],
},
**kwargs,
).sample
refiner_pred = self.refiner_unet(
input_chunks[1],
timestep_chunks[1],
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder
added_cond_kwargs={
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
# "time_ids": added_cond_kwargs_chunked['time_ids'][1],
"time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
},
**kwargs,
).sample
noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
else:
noise_pred = self.refiner_unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
# just use the first second text encoder
added_cond_kwargs={
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
},
**kwargs,
).sample
else:
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
).sample
if do_classifier_free_guidance:
# perform guidance
@@ -852,7 +1011,7 @@ class StableDiffusion:
state_dict[new_key] = v
return state_dict
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
str, Parameter]:
named_params: OrderedDict[str, Parameter] = OrderedDict()
if vae:
@@ -877,6 +1036,10 @@ class StableDiffusion:
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
if refiner:
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
named_params[name] = param
# convert to state dict keys, jsut replace . with _ on keys
if state_dict_keys:
new_named_params = OrderedDict()
@@ -888,6 +1051,64 @@ class StableDiffusion:
return named_params
def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
# load the full refiner since we only train unet
if self.model_config.refiner_name_or_path is None:
raise ValueError("Refiner must be specified to save it")
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
# load the refiner model
dtype = get_torch_dtype(self.dtype)
model_path = self.model_config.refiner_name_or_path
if not os.path.exists(model_path) or os.path.isdir(model_path):
# TODO only load unet??
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
model_path,
dtype=dtype,
device='cpu',
variant="fp16",
use_safetensors=True,
)
else:
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
model_path,
dtype=dtype,
device='cpu',
torch_dtype=self.torch_dtype,
original_config_file=refiner_config_path,
)
# replace original unet
refiner.unet = self.refiner_unet
flush()
diffusers_state_dict = OrderedDict()
for k, v in refiner.vae.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
diffusers_state_dict[new_key] = v
for k, v in refiner.text_encoder_2.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
diffusers_state_dict[new_key] = v
for k, v in refiner.unet.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
diffusers_state_dict[new_key] = v
converted_state_dict = get_ldm_state_dict_from_diffusers(
diffusers_state_dict,
'sdxl_refiner',
device='cpu',
dtype=save_dtype
)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
if self.config_file is not None:
output_path_no_ext = os.path.splitext(output_file)[0]
output_config_path = f"{output_path_no_ext}.yaml"
shutil.copyfile(self.config_file, output_config_path)
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
version_string = '1'
if self.is_v2:
@@ -929,6 +1150,8 @@ class StableDiffusion:
text_encoder=False,
text_encoder_lr=None,
unet_lr=None,
refiner_lr=None,
refiner=False,
default_lr=1e-6,
):
# todo maybe only get locon ones?
@@ -974,6 +1197,20 @@ class StableDiffusion:
print(f"Found {len(params)} trainable parameter in text encoder")
if refiner:
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
refiner_lr = refiner_lr if refiner_lr is not None else default_lr
params = []
for key, diffusers_key in ldm_diffusers_keymap.items():
diffusers_key = f"refiner_{diffusers_key}"
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
if named_params[diffusers_key].requires_grad:
params.append(named_params[diffusers_key])
param_data = {"params": params, "lr": refiner_lr}
trainable_parameters.append(param_data)
print(f"Found {len(params)} trainable parameter in refiner")
return trainable_parameters
def save_device_state(self):
@@ -1021,6 +1258,13 @@ class StableDiffusion:
'requires_grad': requires_grad,
}
if self.refiner_unet is not None:
self.device_state['refiner_unet'] = {
'training': self.refiner_unet.training,
'device': self.refiner_unet.device,
'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
}
def restore_device_state(self):
# restores the device state for all modules
# this is useful for when we want to alter the state and restore it
@@ -1075,6 +1319,14 @@ class StableDiffusion:
self.adapter.train()
else:
self.adapter.eval()
if self.refiner_unet is not None:
self.refiner_unet.to(state['refiner_unet']['device'])
self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
if state['refiner_unet']['training']:
self.refiner_unet.train()
else:
self.refiner_unet.eval()
flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -1088,7 +1340,7 @@ class StableDiffusion:
if device_state_preset in ['cache_latents']:
active_modules = ['vae']
if device_state_preset in ['generate']:
active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
state = copy.deepcopy(empty_preset)
# vae
@@ -1105,6 +1357,13 @@ class StableDiffusion:
'requires_grad': 'unet' in training_modules,
}
if self.refiner_unet is not None:
state['refiner_unet'] = {
'training': 'refiner_unet' in training_modules,
'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
'requires_grad': 'refiner_unet' in training_modules,
}
# text encoder
if isinstance(self.text_encoder, list):
state['text_encoder'] = []