mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 06:57:35 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -26,21 +26,28 @@ from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
|
||||
from toolkit.sampler import get_sampler
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
|
||||
from toolkit.sd_device_states_presets import empty_preset
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||
StableDiffusionXLImg2ImgPipeline
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
UNet2DConditionModel
|
||||
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_REFINER_UNET = "refiner_unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
@@ -52,6 +59,10 @@ DO_NOT_TRAIN_WEIGHTS = [
|
||||
"unet_time_embedding.linear_1.weight",
|
||||
"unet_time_embedding.linear_2.bias",
|
||||
"unet_time_embedding.linear_2.weight",
|
||||
"refiner_unet_time_embedding.linear_1.bias",
|
||||
"refiner_unet_time_embedding.linear_1.weight",
|
||||
"refiner_unet_time_embedding.linear_2.bias",
|
||||
"refiner_unet_time_embedding.linear_2.weight",
|
||||
]
|
||||
|
||||
DeviceStatePreset = Literal['cache_latents', 'generate']
|
||||
@@ -81,10 +92,6 @@ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも
|
||||
|
||||
# if is type checking
|
||||
if typing.TYPE_CHECKING:
|
||||
from diffusers import \
|
||||
StableDiffusionPipeline, \
|
||||
AutoencoderKL, \
|
||||
UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||
|
||||
@@ -116,6 +123,8 @@ class StableDiffusion:
|
||||
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
|
||||
self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
|
||||
|
||||
self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
|
||||
|
||||
# sdxl stuff
|
||||
self.logit_scale = None
|
||||
self.ckppt_info = None
|
||||
@@ -214,7 +223,7 @@ class StableDiffusion:
|
||||
pipln = StableDiffusionPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(model_path):
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
@@ -263,10 +272,47 @@ class StableDiffusion:
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.pipeline = pipe
|
||||
self.load_refiner()
|
||||
self.is_loaded = True
|
||||
|
||||
def load_refiner(self):
|
||||
# for now, we are just going to rely on the TE from the base model
|
||||
# which is TE2 for SDXL and TE for SD (no refiner currently)
|
||||
# and completely ignore a TE that may or may not be packaged with the refiner
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
|
||||
# load the refiner model
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
model_path = self.model_config.refiner_name_or_path
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# TODO only load unet??
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device=self.device_torch,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
original_config_file=refiner_config_path,
|
||||
).to(self.device_torch)
|
||||
|
||||
self.refiner_unet = refiner.unet
|
||||
del refiner
|
||||
flush()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None):
|
||||
def generate_images(
|
||||
self,
|
||||
image_configs: List[GenerateImageConfig],
|
||||
sampler=None,
|
||||
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
|
||||
):
|
||||
merge_multiplier = 1.0
|
||||
# sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if self.network is not None:
|
||||
@@ -289,65 +335,85 @@ class StableDiffusion:
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
|
||||
noise_scheduler = self.noise_scheduler
|
||||
if sampler is not None:
|
||||
if sampler.startswith("sample_"): # sample_dpmpp_2m
|
||||
# using ksampler
|
||||
noise_scheduler = get_sampler('lms')
|
||||
else:
|
||||
noise_scheduler = get_sampler(sampler)
|
||||
|
||||
if sampler.startswith("sample_") and self.is_xl:
|
||||
# using kdiffusion
|
||||
Pipe = StableDiffusionKDiffusionXLPipeline
|
||||
elif self.is_xl:
|
||||
Pipe = StableDiffusionXLPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionPipeline
|
||||
|
||||
extra_args = {}
|
||||
if self.adapter is not None:
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLAdapterPipeline
|
||||
if pipeline is None:
|
||||
noise_scheduler = self.noise_scheduler
|
||||
if sampler is not None:
|
||||
if sampler.startswith("sample_"): # sample_dpmpp_2m
|
||||
# using ksampler
|
||||
noise_scheduler = get_sampler('lms')
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
noise_scheduler = get_sampler(sampler)
|
||||
|
||||
if sampler.startswith("sample_") and self.is_xl:
|
||||
# using kdiffusion
|
||||
Pipe = StableDiffusionKDiffusionXLPipeline
|
||||
elif self.is_xl:
|
||||
Pipe = StableDiffusionXLPipeline
|
||||
else:
|
||||
if self.is_xl:
|
||||
extra_args['add_watermarker'] = False
|
||||
Pipe = StableDiffusionPipeline
|
||||
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
pipeline.watermark = None
|
||||
else:
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=noise_scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
# disable progress bar
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
extra_args = {}
|
||||
if self.adapter is not None:
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLAdapterPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
else:
|
||||
if self.is_xl:
|
||||
extra_args['add_watermarker'] = False
|
||||
|
||||
if sampler.startswith("sample_"):
|
||||
pipeline.set_scheduler(sampler)
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
pipeline.watermark = None
|
||||
else:
|
||||
pipeline = Pipe(
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=noise_scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
# disable progress bar
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if sampler.startswith("sample_"):
|
||||
pipeline.set_scheduler(sampler)
|
||||
|
||||
refiner_pipeline = None
|
||||
if self.refiner_unet:
|
||||
# build refiner pipeline
|
||||
refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=pipeline.vae,
|
||||
unet=self.refiner_unet,
|
||||
text_encoder=None,
|
||||
text_encoder_2=pipeline.text_encoder_2,
|
||||
tokenizer=None,
|
||||
tokenizer_2=pipeline.tokenizer_2,
|
||||
scheduler=pipeline.scheduler,
|
||||
add_watermarker=False,
|
||||
requires_aesthetics_score=True,
|
||||
).to(self.device_torch)
|
||||
# refiner_pipeline.register_to_config(requires_aesthetics_score=False)
|
||||
refiner_pipeline.watermark = None
|
||||
refiner_pipeline.set_progress_bar_config(disable=True)
|
||||
flush()
|
||||
|
||||
start_multiplier = 1.0
|
||||
if self.network is not None:
|
||||
@@ -406,14 +472,20 @@ class StableDiffusion:
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||
|
||||
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
|
||||
if self.refiner_unet is not None:
|
||||
# if we have a refiner loaded, set the denoising end at the refiner start
|
||||
extra['denoising_end'] = self.model_config.refiner_start_at
|
||||
extra['output_type'] = 'latent'
|
||||
if not self.is_xl:
|
||||
raise ValueError("Refiner is only supported for XL models")
|
||||
|
||||
if self.is_xl:
|
||||
# fix guidance rescale for sdxl
|
||||
# was trained on 0.7 (I believe)
|
||||
|
||||
grs = gen_config.guidance_rescale
|
||||
if grs is None or grs < 0.00001:
|
||||
grs = 0.7
|
||||
# if grs is None or grs < 0.00001:
|
||||
# grs = 0.7
|
||||
# grs = 0.0
|
||||
|
||||
if sampler.startswith("sample_"):
|
||||
@@ -454,10 +526,37 @@ class StableDiffusion:
|
||||
**extra
|
||||
).images[0]
|
||||
|
||||
if refiner_pipeline is not None:
|
||||
# slide off just the last 1280 on the last dim as refiner does not use first text encoder
|
||||
# todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
|
||||
refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
|
||||
refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
|
||||
# run through refiner
|
||||
img = refiner_pipeline(
|
||||
# prompt=gen_config.prompt,
|
||||
# prompt_2=gen_config.prompt_2,
|
||||
|
||||
# slice these as it does not use both text encoders
|
||||
# height=gen_config.height,
|
||||
# width=gen_config.width,
|
||||
prompt_embeds=refiner_text_embeds,
|
||||
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
|
||||
negative_prompt_embeds=refiner_unconditional_text_embeds,
|
||||
negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
guidance_scale=gen_config.guidance_scale,
|
||||
guidance_rescale=grs,
|
||||
denoising_start=self.model_config.refiner_start_at,
|
||||
denoising_end=gen_config.num_inference_steps,
|
||||
image=img.unsqueeze(0)
|
||||
).images[0]
|
||||
|
||||
gen_config.save_image(img, i)
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
if refiner_pipeline is not None:
|
||||
del refiner_pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# restore training state
|
||||
@@ -505,7 +604,7 @@ class StableDiffusion:
|
||||
noise = apply_noise_offset(noise, noise_offset)
|
||||
return noise
|
||||
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor):
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
if self.is_xl:
|
||||
bs, ch, h, w = list(latents.shape)
|
||||
@@ -518,7 +617,13 @@ class StableDiffusion:
|
||||
target_size = (height, width)
|
||||
original_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
if requires_aesthetic_score:
|
||||
# refiner
|
||||
# https://huggingface.co/papers/2307.01952
|
||||
aesthetic_score = 6.0 # simulate one
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
|
||||
|
||||
@@ -588,14 +693,68 @@ class StableDiffusion:
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
).sample
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# we have the refiner on the second half of everything. Do Both
|
||||
if do_classifier_free_guidance:
|
||||
raise ValueError("Refiner is not supported with classifier free guidance")
|
||||
|
||||
if self.unet.training:
|
||||
input_chunks = torch.chunk(latent_model_input, 2, dim=0)
|
||||
timestep_chunks = torch.chunk(timestep, 2, dim=0)
|
||||
added_cond_kwargs_chunked = {
|
||||
"text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
|
||||
"time_ids": torch.chunk(add_time_ids, 2, dim=0),
|
||||
}
|
||||
text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)
|
||||
|
||||
# predict the noise residual
|
||||
base_pred = self.unet(
|
||||
input_chunks[0],
|
||||
timestep_chunks[0],
|
||||
encoder_hidden_states=text_embeds_chunks[0],
|
||||
added_cond_kwargs={
|
||||
"text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
|
||||
"time_ids": added_cond_kwargs_chunked['time_ids'][0],
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
refiner_pred = self.refiner_unet(
|
||||
input_chunks[1],
|
||||
timestep_chunks[1],
|
||||
encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
|
||||
# "time_ids": added_cond_kwargs_chunked['time_ids'][1],
|
||||
"time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
|
||||
else:
|
||||
noise_pred = self.refiner_unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
|
||||
# just use the first second text encoder
|
||||
added_cond_kwargs={
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True),
|
||||
},
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
else:
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# perform guidance
|
||||
@@ -852,7 +1011,7 @@ class StableDiffusion:
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, state_dict_keys=False) -> OrderedDict[
|
||||
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> OrderedDict[
|
||||
str, Parameter]:
|
||||
named_params: OrderedDict[str, Parameter] = OrderedDict()
|
||||
if vae:
|
||||
@@ -877,6 +1036,10 @@ class StableDiffusion:
|
||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
if refiner:
|
||||
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
# convert to state dict keys, jsut replace . with _ on keys
|
||||
if state_dict_keys:
|
||||
new_named_params = OrderedDict()
|
||||
@@ -888,6 +1051,64 @@ class StableDiffusion:
|
||||
|
||||
return named_params
|
||||
|
||||
def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
|
||||
|
||||
# load the full refiner since we only train unet
|
||||
if self.model_config.refiner_name_or_path is None:
|
||||
raise ValueError("Refiner must be specified to save it")
|
||||
refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
|
||||
# load the refiner model
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
model_path = self.model_config.refiner_name_or_path
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# TODO only load unet??
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device='cpu',
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
else:
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device='cpu',
|
||||
torch_dtype=self.torch_dtype,
|
||||
original_config_file=refiner_config_path,
|
||||
)
|
||||
# replace original unet
|
||||
refiner.unet = self.refiner_unet
|
||||
flush()
|
||||
|
||||
diffusers_state_dict = OrderedDict()
|
||||
for k, v in refiner.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in refiner.text_encoder_2.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in refiner.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
diffusers_state_dict,
|
||||
'sdxl_refiner',
|
||||
device='cpu',
|
||||
dtype=save_dtype
|
||||
)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
version_string = '1'
|
||||
if self.is_v2:
|
||||
@@ -929,6 +1150,8 @@ class StableDiffusion:
|
||||
text_encoder=False,
|
||||
text_encoder_lr=None,
|
||||
unet_lr=None,
|
||||
refiner_lr=None,
|
||||
refiner=False,
|
||||
default_lr=1e-6,
|
||||
):
|
||||
# todo maybe only get locon ones?
|
||||
@@ -974,6 +1197,20 @@ class StableDiffusion:
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in text encoder")
|
||||
|
||||
if refiner:
|
||||
named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True)
|
||||
refiner_lr = refiner_lr if refiner_lr is not None else default_lr
|
||||
params = []
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
diffusers_key = f"refiner_{diffusers_key}"
|
||||
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
if named_params[diffusers_key].requires_grad:
|
||||
params.append(named_params[diffusers_key])
|
||||
param_data = {"params": params, "lr": refiner_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
|
||||
print(f"Found {len(params)} trainable parameter in refiner")
|
||||
|
||||
return trainable_parameters
|
||||
|
||||
def save_device_state(self):
|
||||
@@ -1021,6 +1258,13 @@ class StableDiffusion:
|
||||
'requires_grad': requires_grad,
|
||||
}
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
self.device_state['refiner_unet'] = {
|
||||
'training': self.refiner_unet.training,
|
||||
'device': self.refiner_unet.device,
|
||||
'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
# restores the device state for all modules
|
||||
# this is useful for when we want to alter the state and restore it
|
||||
@@ -1075,6 +1319,14 @@ class StableDiffusion:
|
||||
self.adapter.train()
|
||||
else:
|
||||
self.adapter.eval()
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
self.refiner_unet.to(state['refiner_unet']['device'])
|
||||
self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
|
||||
if state['refiner_unet']['training']:
|
||||
self.refiner_unet.train()
|
||||
else:
|
||||
self.refiner_unet.eval()
|
||||
flush()
|
||||
|
||||
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
|
||||
@@ -1088,7 +1340,7 @@ class StableDiffusion:
|
||||
if device_state_preset in ['cache_latents']:
|
||||
active_modules = ['vae']
|
||||
if device_state_preset in ['generate']:
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter']
|
||||
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
|
||||
|
||||
state = copy.deepcopy(empty_preset)
|
||||
# vae
|
||||
@@ -1105,6 +1357,13 @@ class StableDiffusion:
|
||||
'requires_grad': 'unet' in training_modules,
|
||||
}
|
||||
|
||||
if self.refiner_unet is not None:
|
||||
state['refiner_unet'] = {
|
||||
'training': 'refiner_unet' in training_modules,
|
||||
'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
|
||||
'requires_grad': 'refiner_unet' in training_modules,
|
||||
}
|
||||
|
||||
# text encoder
|
||||
if isinstance(self.text_encoder, list):
|
||||
state['text_encoder'] = []
|
||||
|
||||
Reference in New Issue
Block a user