Added adapter assistance to SD training

This commit is contained in:
Jaret Burkett
2023-10-14 08:44:53 -06:00
parent 38e441a29c
commit 7909b50d24
2 changed files with 97 additions and 79 deletions

View File

@@ -1,10 +1,12 @@
import os.path
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
@@ -30,10 +32,26 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
    super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
def before_model_load(self):
    pass
def before_dataset_load(self):
self.assistant_adapter = None
# get adapter assistant if one is set
if self.train_config.adapter_assist_name_or_path is not None:
adapter_path = self.train_config.adapter_assist_name_or_path
# don't name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
    adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16"
).to(self.device_torch)
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
def hook_before_train_loop(self):
    # move vae to device if we did not cache latents
    if not self.is_latents_cached:
@@ -44,53 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
if self.adapter_config.image_dir is None:
# adapter needs 0 to 1 values, batch is -1 to 1
adapter_batch = batch.tensor.clone().to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
adapter_batch = (adapter_batch + 1) / 2
return adapter_batch
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.adapter_config.image_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.basename(img_path).split('.')[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
# to the smallest dimension
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
def hook_train_loop(self, batch):
@@ -98,18 +69,21 @@ class SDTrainer(BaseSDTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
has_adapter_img = batch.control_tensor is not None
self.timer.stop('preprocess_batch')
with torch.no_grad():
    adapter_images = None
    sigmas = None
if self.adapter: if has_adapter_img and (self.adapter or self.assistant_adapter):
with self.timer('get_adapter_images'):
    # todo move this to data loader
    if batch.control_tensor is not None:
        adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
    else:
adapter_images = self.get_adapter_images(batch) raise NotImplementedError("Adapter images now must be loaded with dataloader")
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
@@ -128,9 +102,36 @@ class SDTrainer(BaseSDTrainProcess):
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
def get_adapter_multiplier():
if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant.
return 1.0
elif self.train_config.match_adapter_assist:
# training a texture. We want it high
adapter_strength_min = 0.9
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
# adapter_strength_min = 0.5
# adapter_strength_max = 0.8
adapter_strength_min = 0.9
adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return adapter_conditioning_scale
# flush()
with self.timer('grad_setup'):
self.optimizer.zero_grad()
# text encoding
grad_on_text_encoder = False
@@ -148,6 +149,7 @@ class SDTrainer(BaseSDTrainProcess):
# set the weights
network.multiplier = network_weight_list
self.optimizer.zero_grad(set_to_none=True)
# activate network if it exists
with network:
@@ -159,15 +161,44 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter and isinstance(self.adapter, T2IAdapter): if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
with self.timer('encode_adapter'): with torch.set_grad_enabled(self.adapter is not None):
down_block_additional_residuals = self.adapter(adapter_images) adapter = self.adapter if self.adapter else self.assistant_adapter
down_block_additional_residuals = [ adapter_multiplier = get_adapter_multiplier()
sample.to(dtype=dtype) for sample in down_block_additional_residuals with self.timer('encode_adapter'):
] down_block_additional_residuals = adapter(adapter_images)
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals
]
if self.adapter and isinstance(self.adapter, IPAdapter): pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
control_pred = None
if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist:
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# dont use network on this
network.multiplier = 0.0
control_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
control_pred = control_pred.detach()
# remove the residuals as we won't use them on prediction when matching control
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
    with torch.no_grad():
        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
@@ -183,29 +214,13 @@ class SDTrainer(BaseSDTrainProcess):
    **pred_kwargs
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
# # Denoise the latents
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
# weighing = sigmas ** -2.0
#
# # Get the target for loss depending on the prediction type
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
# target = batch.latents # we are computing loss against denoise latents
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
# else:
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
#
# # MSE loss
# loss = torch.mean(
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
# dim=1,
# )
# else:
with self.timer('calculate_loss'):
    noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
if control_pred is not None:
# matching adapter prediction
target = control_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:

View File

@@ -121,6 +121,9 @@ class TrainConfig:
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
class ModelConfig: