Adapter work. Bug fixes. Auto-adjust LR when resuming optimizer state.

Jaret Burkett
2024-03-17 10:21:47 -06:00
parent 72de68d8aa
commit 016687bda1
8 changed files with 84 additions and 15 deletions

View File

@@ -7,7 +7,7 @@ import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLAdapterPipeline
from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -100,25 +100,43 @@ class ReferenceGenerator(BaseExtensionProcess):
if self.generate_config.t2i_adapter_path is not None:
self.adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
self.generate_config.t2i_adapter_path,
torch_dtype=self.torch_dtype,
varient="fp16"
).to(device)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to(device)
pipe = StableDiffusionXLAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=get_sampler(self.generate_config.sampler),
adapter=self.adapter,
).to(device)
if self.model_config.is_xl:
pipe = StableDiffusionXLAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=get_sampler(self.generate_config.sampler),
adapter=self.adapter,
).to(device, dtype=self.torch_dtype)
else:
pipe = StableDiffusionAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=get_sampler(self.generate_config.sampler),
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
adapter=self.adapter,
).to(device, dtype=self.torch_dtype)
pipe.set_progress_bar_config(disable=True)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
num_batches = len(self.data_loader)
@@ -176,6 +194,7 @@ class ReferenceGenerator(BaseExtensionProcess):
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
guidance_scale=self.generate_config.guidance_scale,
).images[0]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
gen_images.save(output_path)
# save caption
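
For reference, the branch added above picks the SDXL or SD1.5 flavour of the T2I-Adapter pipeline. A standalone sketch of the same pattern follows; the checkpoint names and the is_xl flag are illustrative, not taken from this commit.

import torch
from diffusers import (
    T2IAdapter,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionAdapterPipeline,
)

is_xl = True  # flip for an SD1.x base model

if is_xl:
    adapter = T2IAdapter.from_pretrained(
        "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
        torch_dtype=torch.float16, variant="fp16",
    )
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        adapter=adapter, torch_dtype=torch.float16, variant="fp16",
    )
else:
    adapter = T2IAdapter.from_pretrained(
        "TencentARC/t2iadapter_depth_sd14v1", torch_dtype=torch.float16,
    )
    pipe = StableDiffusionAdapterPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        adapter=adapter, torch_dtype=torch.float16, safety_checker=None,
    )

pipe.to("cuda" if torch.cuda.is_available() else "cpu")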

View File

@@ -796,6 +796,7 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
rescale_cfg=self.train_config.cfg_rescale,
**pred_kwargs # adapter residuals in here
)
if was_unet_training:
@@ -1355,6 +1356,7 @@ class SDTrainer(BaseSDTrainProcess):
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
**pred_kwargs
)
self.after_unet_predict()

View File

@@ -1352,6 +1352,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# try to load
# previous param groups
# previous_params = copy.deepcopy(optimizer.param_groups)
previous_lrs = []
for group in optimizer.param_groups:
previous_lrs.append(group['lr'])
try:
print(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1360,6 +1364,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)
# update the optimizer LR from the params
print(f"Updating optimizer LR from params")
if len(previous_lrs) > 0:
for i, group in enumerate(optimizer.param_groups):
group['lr'] = previous_lrs[i]
group['initial_lr'] = previous_lrs[i]
# Update the learning rates if they changed
# optimizer.param_groups = previous_params
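
The snapshot/restore above works around a resume quirk: optimizer.load_state_dict() also restores the old learning rates, so an LR changed in the config between runs would be silently overwritten. Capturing the freshly configured LRs before loading and writing them back afterwards keeps the new value. A standalone sketch of the same idea:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
saved_state = optimizer.state_dict()            # pretend this was saved to disk last run

optimizer = torch.optim.AdamW(params, lr=3e-4)  # user raised the LR in the config
previous_lrs = [group["lr"] for group in optimizer.param_groups]

optimizer.load_state_dict(saved_state)          # would quietly reset lr back to 1e-4
for group, lr in zip(optimizer.param_groups, previous_lrs):
    group["lr"] = lr
    group["initial_lr"] = lr                    # keeps LR schedulers consistent on resume

print(optimizer.param_groups[0]["lr"])          # 0.0003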

View File

@@ -42,6 +42,7 @@ dataset_config = DatasetConfig(
resolution=resolution,
caption_ext='json',
default_caption='default',
clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
buckets=True,
bucket_tolerance=bucket_tolerance,
poi='person',

View File

@@ -218,7 +218,7 @@ class TrainConfig:
self.xformers = kwargs.get('xformers', False)
self.sdp = kwargs.get('sdp', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.train_text_encoder = kwargs.get('train_text_encoder', False)
self.train_refiner = kwargs.get('train_refiner', True)
self.train_turbo = kwargs.get('train_turbo', False)
self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
@@ -298,6 +298,9 @@ class TrainConfig:
self.do_random_cfg = kwargs.get('do_random_cfg', False)
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
self.cfg_rescale = kwargs.get('cfg_rescale', None)
if self.cfg_rescale is None:
self.cfg_rescale = self.cfg_scale
# applies the inverse of the prediction mean and std to the target to correct
# for norm drift
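
Since cfg_rescale falls back to cfg_scale, and (as the StableDiffusion change below shows) rescaling is skipped when the two are equal, existing configs behave exactly as before unless cfg_rescale is set explicitly. A minimal sketch of that fallback, using a plain dict in place of the config kwargs:

kwargs = {"cfg_scale": 4.0}                      # no cfg_rescale key in the config

cfg_scale = kwargs.get("cfg_scale", 1.0)
cfg_rescale = kwargs.get("cfg_rescale", None)
if cfg_rescale is None:
    cfg_rescale = cfg_scale                      # equal values -> rescaling is a no-op downstream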

View File

@@ -132,9 +132,16 @@ class CustomAdapter(torch.nn.Module):
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
vision_tokens = vision_tokens + 1
vision_hidden_size = self.vision_encoder.config.hidden_size
if self.config.clip_layer == 'image_embeds':
vision_tokens = 1
vision_hidden_size = self.vision_encoder.config.projection_dim
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=self.vision_encoder.config.hidden_size,
vision_hidden_size=vision_hidden_size,
sd=self.sd_ref()
)
elif self.adapter_type == 'text_encoder':
@@ -731,7 +738,14 @@ class CustomAdapter(torch.nn.Module):
clip_image, output_hidden_states=True
)
img_embeds = id_embeds['last_hidden_state']
if self.config.clip_layer == 'penultimate_hidden_states':
img_embeds = id_embeds.hidden_states[-2]
elif self.config.clip_layer == 'last_hidden_state':
img_embeds = id_embeds.hidden_states[-1]
elif self.config.clip_layer == 'image_embeds':
img_embeds = id_embeds.image_embeds
else:
raise ValueError(f"unknown clip layer: {self.config.clip_layer}")
if self.config.quad_image:
# get the outputs of the quad
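
The clip_layer branches above choose which CLIP vision output feeds the InstantLoRA module: the token-level hidden states keep one embedding per patch plus the class token at hidden_size width, while image_embeds is a single pooled vector at projection_dim width, which is why vision_tokens drops to 1 and the projection dim is used in that case. A standalone sketch with a real CLIP vision encoder (the checkpoint name is illustrative, not necessarily the adapter's own encoder):

import torch
from transformers import CLIPVisionModelWithProjection

encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

with torch.no_grad():
    out = encoder(torch.randn(1, 3, 224, 224), output_hidden_states=True)

penultimate = out.hidden_states[-2]  # (1, 257, 1024): 256 patch tokens + CLS, hidden_size wide
last = out.hidden_states[-1]         # (1, 257, 1024)
pooled = out.image_embeds            # (1, 768): single pooled vector, projection_dim wide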

View File

@@ -106,6 +106,9 @@ class InstantLoRAModule(torch.nn.Module):
# this will be used to add the vector to the original forward
def forward(self, img_embeds):
# expand token rank if only rank 2
if len(img_embeds.shape) == 2:
img_embeds = img_embeds.unsqueeze(1)
img_embeds = self.resampler(img_embeds)
self.img_embeds = img_embeds
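
The rank check added above covers the pooled image_embeds case: a (batch, dim) tensor gets a singleton token axis so the resampler always sees (batch, tokens, dim). In isolation:

import torch

img_embeds = torch.randn(4, 768)          # pooled image_embeds: (batch, dim)
if len(img_embeds.shape) == 2:
    img_embeds = img_embeds.unsqueeze(1)  # -> (4, 1, 768), one "token" per image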

View File

@@ -863,6 +863,7 @@ class StableDiffusion:
unconditional_embeddings: Union[PromptEmbeds, None] = None,
is_input_scaled=False,
detach_unconditional=False,
rescale_cfg=None,
**kwargs,
):
# get the embeddings
@@ -1111,6 +1112,21 @@ class StableDiffusion:
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if rescale_cfg is not None and rescale_cfg != guidance_scale:
with torch.no_grad():
# do cfg at the target rescale so we can match it
target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
noise_pred_text - noise_pred_uncond
)
target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()
# match the mean and std
noise_pred = (noise_pred - pred_mean) / pred_std
noise_pred = (noise_pred * target_std) + target_mean
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
if guidance_rescale > 0.0:
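
The rescale_cfg block recomputes CFG at the target scale, then shifts and scales the high-guidance prediction so its per-sample mean and standard deviation match that target, keeping the prediction statistics in range while training with a large cfg_scale. A standalone sketch of the statistic matching on dummy tensors:

import torch

guidance_scale, rescale_cfg = 7.5, 3.0
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_pred_uncond = torch.randn(2, 4, 64, 64)

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

with torch.no_grad():
    # what the prediction would look like at the target guidance scale
    target = noise_pred_uncond + rescale_cfg * (noise_pred_text - noise_pred_uncond)
    target_mean = target.mean([1, 2, 3], keepdim=True)
    target_std = target.std([1, 2, 3], keepdim=True)
    pred_mean = noise_pred.mean([1, 2, 3], keepdim=True)
    pred_std = noise_pred.std([1, 2, 3], keepdim=True)

# shift/scale the high-CFG prediction onto the target statistics
noise_pred = (noise_pred - pred_mean) / pred_std * target_std + target_mean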