Added adapter assistance to SD training

This commit is contained in:
Jaret Burkett
2023-10-14 08:44:53 -06:00
parent 38e441a29c
commit 7909b50d24
2 changed files with 97 additions and 79 deletions

View File

@@ -1,10 +1,12 @@
import os.path
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
@@ -30,10 +32,26 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
def before_model_load(self):
pass
def before_dataset_load(self):
self.assistant_adapter = None
# get adapter assistant if one is set
if self.train_config.adapter_assist_name_or_path is not None:
adapter_path = self.train_config.adapter_assist_name_or_path
# dont name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
).to(self.device_torch)
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
def hook_before_train_loop(self):
# move vae to device if we did not cache latents
if not self.is_latents_cached:
@@ -44,53 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
if self.adapter_config.image_dir is None:
# adapter needs 0 to 1 values, batch is -1 to 1
adapter_batch = batch.tensor.clone().to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
adapter_batch = (adapter_batch + 1) / 2
return adapter_batch
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.adapter_config.image_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.basename(img_path).split('.')[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
# to the smallest dimension
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
def hook_train_loop(self, batch):
@@ -98,18 +69,21 @@ class SDTrainer(BaseSDTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
has_adapter_img = batch.control_tensor is not None
self.timer.stop('preprocess_batch')
with torch.no_grad():
adapter_images = None
sigmas = None
if self.adapter:
if has_adapter_img and (self.adapter or self.assistant_adapter):
with self.timer('get_adapter_images'):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
else:
adapter_images = self.get_adapter_images(batch)
raise NotImplementedError("Adapter images now must be loaded with dataloader")
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
@@ -128,9 +102,36 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
def get_adapter_multiplier():
if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant.
return 1.0
elif self.train_config.match_adapter_assist:
# training a texture. We want it high
adapter_strength_min = 0.9
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
# adapter_strength_min = 0.5
# adapter_strength_max = 0.8
adapter_strength_min = 0.9
adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return adapter_conditioning_scale
# flush()
with self.timer('grad_setup'):
self.optimizer.zero_grad()
# text encoding
grad_on_text_encoder = False
@@ -148,6 +149,7 @@ class SDTrainer(BaseSDTrainProcess):
# set the weights
network.multiplier = network_weight_list
self.optimizer.zero_grad(set_to_none=True)
# activate network if it exits
with network:
@@ -159,15 +161,44 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter and isinstance(self.adapter, T2IAdapter):
with self.timer('encode_adapter'):
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.adapter if self.adapter else self.assistant_adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals
]
if self.adapter and isinstance(self.adapter, IPAdapter):
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
control_pred = None
if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist:
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# dont use network on this
network.multiplier = 0.0
control_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
control_pred = control_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
@@ -183,29 +214,13 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
# # Denoise the latents
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
# weighing = sigmas ** -2.0
#
# # Get the target for loss depending on the prediction type
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
# target = batch.latents # we are computing loss against denoise latents
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
#
# # MSE loss
# loss = torch.mean(
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
# dim=1,
# )
# else:
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
if control_pred is not None:
# matching adapter prediction
target = control_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:

View File

@@ -121,6 +121,9 @@ class TrainConfig:
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
class ModelConfig: