mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Adapter work. Bug fixes. Auto adjust LR when resuming optimizer.
@@ -7,7 +7,7 @@ import numpy as np
 from PIL import Image
 from diffusers import T2IAdapter
 from torch.utils.data import DataLoader
-from diffusers import StableDiffusionXLAdapterPipeline
+from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
 from tqdm import tqdm
 
 from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -100,25 +100,43 @@ class ReferenceGenerator(BaseExtensionProcess):
         if self.generate_config.t2i_adapter_path is not None:
             self.adapter = T2IAdapter.from_pretrained(
-                "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
+                self.generate_config.t2i_adapter_path,
+                torch_dtype=self.torch_dtype,
+                varient="fp16"
             ).to(device)
 
         midas_depth = MidasDetector.from_pretrained(
             "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
         ).to(device)
 
-        pipe = StableDiffusionXLAdapterPipeline(
-            vae=self.sd.vae,
-            unet=self.sd.unet,
-            text_encoder=self.sd.text_encoder[0],
-            text_encoder_2=self.sd.text_encoder[1],
-            tokenizer=self.sd.tokenizer[0],
-            tokenizer_2=self.sd.tokenizer[1],
-            scheduler=get_sampler(self.generate_config.sampler),
-            adapter=self.adapter,
-        ).to(device)
+        if self.model_config.is_xl:
+            pipe = StableDiffusionXLAdapterPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder[0],
+                text_encoder_2=self.sd.text_encoder[1],
+                tokenizer=self.sd.tokenizer[0],
+                tokenizer_2=self.sd.tokenizer[1],
+                scheduler=get_sampler(self.generate_config.sampler),
+                adapter=self.adapter,
+            ).to(device, dtype=self.torch_dtype)
+        else:
+            pipe = StableDiffusionAdapterPipeline(
+                vae=self.sd.vae,
+                unet=self.sd.unet,
+                text_encoder=self.sd.text_encoder,
+                tokenizer=self.sd.tokenizer,
+                scheduler=get_sampler(self.generate_config.sampler),
+                safety_checker=None,
+                feature_extractor=None,
+                requires_safety_checker=False,
+                adapter=self.adapter,
+            ).to(device, dtype=self.torch_dtype)
         pipe.set_progress_bar_config(disable=True)
 
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
 
         self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
 
         num_batches = len(self.data_loader)
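Note on the hunk above: the adapter checkpoint now comes from generate_config.t2i_adapter_path instead of the hard-coded TencentARC depth adapter, and the pipeline class follows model_config.is_xl. (The varient= kwarg looks like a misspelling of diffusers' variant, so the fp16 variant file is probably not actually being selected.) A minimal standalone sketch of the same selection, loading from Hub checkpoints rather than the toolkit's in-memory components; the model ids are illustrative assumptions:

# Sketch: pick the adapter pipeline class by model type.
import torch
from diffusers import (
    T2IAdapter,
    StableDiffusionAdapterPipeline,
    StableDiffusionXLAdapterPipeline,
)

def load_adapter_pipeline(base_model: str, adapter_path: str, is_xl: bool, device="cuda"):
    dtype = torch.float16
    # variant="fp16" assumes the adapter checkpoint ships fp16 weights
    adapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=dtype, variant="fp16")
    pipeline_cls = StableDiffusionXLAdapterPipeline if is_xl else StableDiffusionAdapterPipeline
    pipe = pipeline_cls.from_pretrained(base_model, adapter=adapter, torch_dtype=dtype)
    return pipe.to(device)

# e.g. load_adapter_pipeline("stabilityai/stable-diffusion-xl-base-1.0",
#                            "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", is_xl=True)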
@@ -176,6 +194,7 @@ class ReferenceGenerator(BaseExtensionProcess):
                 adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
                 guidance_scale=self.generate_config.guidance_scale,
             ).images[0]
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
             gen_images.save(output_path)
 
             # save caption
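The new os.makedirs guard matters because PIL's Image.save does not create missing parent directories. The pattern in isolation (the path is an arbitrary example):

import os
from PIL import Image

output_path = "output/refs/0001/img.png"   # example path; directories may not exist yet
os.makedirs(os.path.dirname(output_path), exist_ok=True)  # no-op if already present
Image.new("RGB", (8, 8)).save(output_path)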
@@ -796,6 +796,7 @@ class SDTrainer(BaseSDTrainProcess):
                 unconditional_embeddings=unconditional_embeds,
                 timestep=timesteps,
                 guidance_scale=self.train_config.cfg_scale,
+                rescale_cfg=self.train_config.cfg_rescale,
                 **pred_kwargs  # adapter residuals in here
             )
             if was_unet_training:
@@ -1355,6 +1356,7 @@ class SDTrainer(BaseSDTrainProcess):
                 timestep=timesteps,
                 guidance_scale=self.train_config.cfg_scale,
                 detach_unconditional=False,
+                rescale_cfg=self.train_config.cfg_rescale,
                 **pred_kwargs
             )
             self.after_unet_predict()
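Both predict_noise call sites above now forward train_config.cfg_rescale. Because the config defaults cfg_rescale to cfg_scale (see the TrainConfig hunk further down), rescaling stays inert unless a distinct value is set. A sketch of the guard, under that assumption:

def should_rescale(guidance_scale: float, rescale_cfg) -> bool:
    # mirrors the condition added in the StableDiffusion hunk at the end:
    # None or a value equal to guidance_scale leaves the prediction alone
    return rescale_cfg is not None and rescale_cfg != guidance_scale

assert should_rescale(3.0, 1.0)
assert not should_rescale(3.0, 3.0)   # the default: cfg_rescale == cfg_scale
assert not should_rescale(3.0, None)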
@@ -1352,6 +1352,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # try to load
             # previous param groups
             # previous_params = copy.deepcopy(optimizer.param_groups)
+            previous_lrs = []
+            for group in optimizer.param_groups:
+                previous_lrs.append(group['lr'])
+
             try:
                 print(f"Loading optimizer state from {optimizer_state_file_path}")
                 optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1360,6 +1364,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print(f"Failed to load optimizer state from {optimizer_state_file_path}")
                 print(e)
 
+            # update the optimizer LR from the params
+            print(f"Updating optimizer LR from params")
+            if len(previous_lrs) > 0:
+                for i, group in enumerate(optimizer.param_groups):
+                    group['lr'] = previous_lrs[i]
+                    group['initial_lr'] = previous_lrs[i]
+
             # Update the learning rates if they changed
             # optimizer.param_groups = previous_params
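These two hunks are the "auto adjust LR when resuming optimizer" fix from the commit message: optimizer.load_state_dict restores the checkpointed param_groups, including their learning rates, so an LR changed in the config between runs would be silently reverted. Snapshotting the freshly configured LRs first and writing them back afterwards keeps the resumed optimizer state while honoring the new LR. A self-contained demonstration of the failure and the fix:

import torch

params = [torch.nn.Parameter(torch.zeros(2))]
opt = torch.optim.AdamW(params, lr=1e-4)
saved = opt.state_dict()                      # checkpoint from a previous run

opt = torch.optim.AdamW(params, lr=5e-5)      # user lowered the LR in the config
previous_lrs = [g['lr'] for g in opt.param_groups]

opt.load_state_dict(saved)
assert opt.param_groups[0]['lr'] == 1e-4      # old LR clobbered the new one

for i, group in enumerate(opt.param_groups):  # the fix: re-apply configured LRs
    group['lr'] = previous_lrs[i]
    group['initial_lr'] = previous_lrs[i]
assert opt.param_groups[0]['lr'] == 5e-5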
@@ -42,6 +42,7 @@ dataset_config = DatasetConfig(
     resolution=resolution,
     caption_ext='json',
     default_caption='default',
+    clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
     poi='person',
@@ -218,7 +218,7 @@ class TrainConfig:
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
         self.train_unet = kwargs.get('train_unet', True)
-        self.train_text_encoder = kwargs.get('train_text_encoder', True)
+        self.train_text_encoder = kwargs.get('train_text_encoder', False)
         self.train_refiner = kwargs.get('train_refiner', True)
         self.train_turbo = kwargs.get('train_turbo', False)
         self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
@@ -298,6 +298,9 @@ class TrainConfig:
         self.do_random_cfg = kwargs.get('do_random_cfg', False)
         self.cfg_scale = kwargs.get('cfg_scale', 1.0)
         self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
+        self.cfg_rescale = kwargs.get('cfg_rescale', None)
+        if self.cfg_rescale is None:
+            self.cfg_rescale = self.cfg_scale
 
         # applies the inverse of the prediction mean and std to the target to correct
         # for norm drift
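The fallback means an unset cfg_rescale resolves to cfg_scale, which (per the guard sketched earlier) disables rescaling entirely. Behavior sketch:

def resolve_cfg(kwargs: dict):
    # same resolution logic as the TrainConfig hunk above
    cfg_scale = kwargs.get('cfg_scale', 1.0)
    cfg_rescale = kwargs.get('cfg_rescale', None)
    if cfg_rescale is None:
        cfg_rescale = cfg_scale
    return cfg_scale, cfg_rescale

assert resolve_cfg({'cfg_scale': 3.0}) == (3.0, 3.0)                      # rescale inert
assert resolve_cfg({'cfg_scale': 3.0, 'cfg_rescale': 1.0}) == (3.0, 1.0)  # rescale active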
@@ -132,9 +132,16 @@ class CustomAdapter(torch.nn.Module):
             vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
             if self.config.image_encoder_arch == 'clip':
                 vision_tokens = vision_tokens + 1
+
+            vision_hidden_size = self.vision_encoder.config.hidden_size
+
+            if self.config.clip_layer == 'image_embeds':
+                vision_tokens = 1
+                vision_hidden_size = self.vision_encoder.config.projection_dim
+
             self.ilora_module = InstantLoRAModule(
                 vision_tokens=vision_tokens,
-                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                vision_hidden_size=vision_hidden_size,
                 sd=self.sd_ref()
             )
         elif self.adapter_type == 'text_encoder':
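The module's expected token count and width now track which CLIP output feeds it: patch-grid hidden states give (image_size / patch_size)^2 tokens, plus one CLS token for CLIP, at hidden_size width, while the pooled image_embeds collapse to a single token at projection_dim width. Worked numbers, assuming a ViT-L/14 tower at 224 px (the dims are illustrative):

# Assumed ViT-L/14 config values, for illustration only.
image_size, patch_size = 224, 14
hidden_size, projection_dim = 1024, 768

vision_tokens = (image_size // patch_size) ** 2   # 256 patch tokens
vision_tokens += 1                                # + CLS token for CLIP -> 257
vision_hidden_size = hidden_size                  # 1024-wide hidden states

clip_layer = 'image_embeds'
if clip_layer == 'image_embeds':                  # pooled + projected output instead
    vision_tokens = 1                             # a single embedding vector
    vision_hidden_size = projection_dim           # 768-wide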
@@ -731,7 +738,14 @@ class CustomAdapter(torch.nn.Module):
                     clip_image, output_hidden_states=True
                 )
 
-                img_embeds = id_embeds['last_hidden_state']
+                if self.config.clip_layer == 'penultimate_hidden_states':
+                    img_embeds = id_embeds.hidden_states[-2]
+                elif self.config.clip_layer == 'last_hidden_state':
+                    img_embeds = id_embeds.hidden_states[-1]
+                elif self.config.clip_layer == 'image_embeds':
+                    img_embeds = id_embeds.image_embeds
+                else:
+                    raise ValueError(f"unknown clip layer: {self.config.clip_layer}")
 
                 if self.config.quad_image:
                     # get the outputs of the quat
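The selection now taps three points on the vision encoder; hidden_states is only populated when the encoder is called with output_hidden_states=True, as in the hunk. A sketch against transformers' CLIP vision model (model id and shapes assume ViT-L/14 at 224 px):

import torch
from transformers import CLIPVisionModelWithProjection

model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pixels = torch.randn(1, 3, 224, 224)
out = model(pixels, output_hidden_states=True)

penultimate = out.hidden_states[-2]   # (1, 257, 1024): tokens before the last block
last = out.hidden_states[-1]          # (1, 257, 1024): last encoder layer tokens
pooled = out.image_embeds             # (1, 768): pooled and projected vector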
@@ -106,6 +106,9 @@ class InstantLoRAModule(torch.nn.Module):
         # this will be used to add the vector to the original forward
 
     def forward(self, img_embeds):
+        # expand token rank if only rank 2
+        if len(img_embeds.shape) == 2:
+            img_embeds = img_embeds.unsqueeze(1)
         img_embeds = self.resampler(img_embeds)
         self.img_embeds = img_embeds
 
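This guards the image_embeds path above, which hands the module a rank-2 (batch, dim) tensor while the resampler expects (batch, tokens, dim). In isolation:

import torch

img_embeds = torch.randn(4, 768)          # pooled embeds: (batch, dim)
if len(img_embeds.shape) == 2:
    img_embeds = img_embeds.unsqueeze(1)  # -> (batch, 1, dim): one token
assert img_embeds.shape == (4, 1, 768)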
@@ -863,6 +863,7 @@ class StableDiffusion:
         unconditional_embeddings: Union[PromptEmbeds, None] = None,
         is_input_scaled=False,
         detach_unconditional=False,
+        rescale_cfg=None,
         **kwargs,
     ):
         # get the embeddings
@@ -1111,6 +1112,21 @@ class StableDiffusion:
             noise_pred = noise_pred_uncond + guidance_scale * (
                 noise_pred_text - noise_pred_uncond
             )
+            if rescale_cfg is not None and rescale_cfg != guidance_scale:
+                with torch.no_grad():
+                    # do cfg at the target rescale so we can match it
+                    target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                    target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
+                    target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
+
+                    pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
+                    pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()
+
+                # match the mean and std
+                noise_pred = (noise_pred - pred_mean) / pred_std
+                noise_pred = (noise_pred * target_std) + target_mean
 
             # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
             if guidance_rescale > 0.0:
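The new block rescales the classifier-free-guidance prediction made at guidance_scale so its per-sample mean and std match a reference prediction made at rescale_cfg, counteracting the norm drift that high guidance scales introduce during training; the statistics are detached, so gradients flow only through noise_pred itself in the final normalization. The same math as a standalone function; shapes assume (batch, channels, height, width):

import torch

def rescale_cfg_pred(uncond, text, guidance_scale, rescale_cfg):
    # classifier-free-guidance prediction at the training guidance scale
    pred = uncond + guidance_scale * (text - uncond)
    if rescale_cfg is None or rescale_cfg == guidance_scale:
        return pred
    with torch.no_grad():
        # reference prediction at the target scale; take its statistics
        target = uncond + rescale_cfg * (text - uncond)
        target_mean = target.mean([1, 2, 3], keepdim=True)
        target_std = target.std([1, 2, 3], keepdim=True)
        pred_mean = pred.mean([1, 2, 3], keepdim=True)
        pred_std = pred.std([1, 2, 3], keepdim=True)
    # normalize the prediction, then shift/scale to the target statistics
    return (pred - pred_mean) / pred_std * target_std + target_mean

uncond, text = torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64)
out = rescale_cfg_pred(uncond, text, guidance_scale=7.5, rescale_cfg=1.0)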