mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 17:49:49 +00:00
Added adapter assistance to SD training
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import os.path
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
@@ -30,10 +32,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
    """Initialize the SD trainer process.

    Args:
        process_id: numeric id of this process within the job.
        job: owning job object (passed through to the base process).
        config: ordered config mapping for this process.
        **kwargs: forwarded to ``BaseSDTrainProcess.__init__``.
    """
    super().__init__(process_id, job, config, **kwargs)
    # Assign None, not just annotate: the original bare annotation
    # (`self.assistant_adapter: Union[...]`) declares a type but never
    # creates the attribute, so any access before before_dataset_load()
    # ran would raise AttributeError. before_dataset_load() re-assigns it.
    self.assistant_adapter: Union['T2IAdapter', None] = None
|
||||
|
||||
def before_model_load(self):
    """Hook called before the base model is loaded.

    Intentionally a no-op for this trainer; subclass hook kept for the
    BaseSDTrainProcess lifecycle.
    """
    pass
|
||||
|
||||
def before_dataset_load(self):
    """Load the assistant T2IAdapter, if one is configured, before the dataset loads.

    The assistant adapter is only used to guide training — it is never
    trained itself — so it is loaded frozen (eval mode, no gradients) and
    moved to the training device.
    """
    self.assistant_adapter = None
    # get adapter assistant if one is set
    if self.train_config.adapter_assist_name_or_path is not None:
        adapter_path = self.train_config.adapter_assist_name_or_path

        # don't name this adapter since we are not training it
        # FIX: the kwarg was misspelled `varient="fp16"`; diffusers'
        # from_pretrained expects `variant`, so the fp16 weight files were
        # silently never selected.
        self.assistant_adapter = T2IAdapter.from_pretrained(
            adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16"
        ).to(self.device_torch)
        # frozen: inference mode, no grads — we only consume its residuals
        self.assistant_adapter.eval()
        self.assistant_adapter.requires_grad_(False)
        flush()
|
||||
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
# move vae to device if we did not cache latents
|
||||
if not self.is_latents_cached:
|
||||
@@ -44,53 +62,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
    """Build the adapter conditioning-image tensor for a batch.

    If no ``adapter_config.image_dir`` is set, the batch images themselves
    are used, rescaled from [-1, 1] to the [0, 1] range the adapter expects.
    Otherwise, for each batch item a control image with the same base
    filename is looked up in ``image_dir``, resized/center-cropped to the
    batch crop size, transformed, and stacked.

    Args:
        batch: the current data-loader batch.

    Returns:
        A tensor of adapter images on ``self.device_torch`` in the training
        dtype (presumably shape (B, C, H, W) — matches batch order).

    Raises:
        FileNotFoundError: if a batch item has no matching control image.
    """
    if self.adapter_config.image_dir is None:
        # adapter needs 0 to 1 values, batch is -1 to 1
        adapter_batch = batch.tensor.clone().to(
            self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
        )
        adapter_batch = (adapter_batch + 1) / 2
        return adapter_batch
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
    adapter_folder_path = self.adapter_config.image_dir
    adapter_images = []
    # loop through images
    for file_item in batch.file_items:
        img_path = file_item.path
        # FIX: use splitext, not split('.')[0] — the latter truncates
        # filenames that contain dots (e.g. "img.v2.png" -> "img").
        file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
        # find the matching control image
        for ext in img_ext_list:
            candidate = os.path.join(adapter_folder_path, file_name_no_ext + ext)
            if os.path.exists(candidate):
                adapter_images.append(candidate)
                break
        else:
            # FIX: fail loudly. The original silently skipped missing files,
            # which misaligned adapter_images with batch.file_items — every
            # subsequent item would be paired with the wrong control image.
            raise FileNotFoundError(
                f"No adapter image found for {img_path} in {adapter_folder_path}"
            )
    width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
    adapter_tensors = []
    # load images with torch transforms
    for idx, adapter_image in enumerate(adapter_images):
        # we need to centrally crop the largest dimension of the image to
        # match the batch shape after scaling to the smallest dimension
        img: Image.Image = Image.open(adapter_image)
        if img.width > img.height:
            # scale down so height is the same as batch
            new_height = height
            new_width = int(img.width * (height / img.height))
        else:
            new_width = width
            new_height = int(img.height * (width / img.width))

        img = img.resize((new_width, new_height))
        crop_fn = transforms.CenterCrop((height, width))
        # crop the center to match batch
        img = crop_fn(img)
        img = adapter_transforms(img)
        adapter_tensors.append(img)

    # stack them
    adapter_tensors = torch.stack(adapter_tensors).to(
        self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
    )
    return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
@@ -98,18 +69,21 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
if has_adapter_img and (self.adapter or self.assistant_adapter):
|
||||
with self.timer('get_adapter_images'):
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
else:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
raise NotImplementedError("Adapter images now must be loaded with dataloader")
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
@@ -128,9 +102,36 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
def get_adapter_multiplier():
    # Conditioning-scale multiplier applied to the adapter's residuals.
    # Returns a plain float (1.0) when a T2IAdapter is being trained
    # directly, otherwise a 1-element tensor sampled uniformly from a
    # strength range — note the mixed return types (float vs. tensor).
    # Closes over `self` and `dtype` from the enclosing training loop.
    if self.adapter and isinstance(self.adapter, T2IAdapter):
        # training a t2i adapter, not using as assistant.
        return 1.0
    elif self.train_config.match_adapter_assist:
        # training a texture. We want it high
        adapter_strength_min = 0.9
        adapter_strength_max = 1.0
    else:
        # training with assistance, we want it low
        # adapter_strength_min = 0.5
        # adapter_strength_max = 0.8
        # NOTE(review): the comment above says "low", but the active
        # range below (0.9–1.1) is higher than the match_adapter_assist
        # branch — confirm this is intentional.
        adapter_strength_min = 0.9
        adapter_strength_max = 1.1

    # uniform sample in [0, 1) on the training device/dtype
    adapter_conditioning_scale = torch.rand(
        (1,), device=self.device_torch, dtype=dtype
    )

    # remap [0, 1] -> [adapter_strength_min, adapter_strength_max]
    adapter_conditioning_scale = value_map(
        adapter_conditioning_scale,
        0.0,
        1.0,
        adapter_strength_min,
        adapter_strength_max
    )
    return adapter_conditioning_scale
|
||||
|
||||
# flush()
|
||||
with self.timer('grad_setup'):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
@@ -148,6 +149,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# activate network if it exits
|
||||
with network:
|
||||
@@ -159,15 +161,44 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = self.adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
adapter = self.adapter if self.adapter else self.assistant_adapter
|
||||
adapter_multiplier = get_adapter_multiplier()
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = adapter(adapter_images)
|
||||
if self.assistant_adapter:
|
||||
# not training. detach
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals
|
||||
]
|
||||
else:
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals
|
||||
]
|
||||
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
control_pred = None
|
||||
if has_adapter_img and self.assistant_adapter and self.train_config.match_adapter_assist:
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
with torch.no_grad():
|
||||
# dont use network on this
|
||||
network.multiplier = 0.0
|
||||
control_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
)
|
||||
control_pred = control_pred.detach()
|
||||
# remove the residuals as we wont use them on prediction when matching control
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
# restore network
|
||||
network.multiplier = network_weight_list
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
@@ -183,29 +214,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**pred_kwargs
|
||||
)
|
||||
|
||||
# if self.adapter:
|
||||
# # todo, diffusers does this on t2i training, is it better approach?
|
||||
# # Denoise the latents
|
||||
# denoised_latents = noise_pred * (-sigmas) + noisy_latents
|
||||
# weighing = sigmas ** -2.0
|
||||
#
|
||||
# # Get the target for loss depending on the prediction type
|
||||
# if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
# target = batch.latents # we are computing loss against denoise latents
|
||||
# elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
|
||||
# else:
|
||||
# raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
#
|
||||
# # MSE loss
|
||||
# loss = torch.mean(
|
||||
# (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
||||
# dim=1,
|
||||
# )
|
||||
# else:
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
|
||||
if control_pred is not None:
|
||||
# matching adapter prediction
|
||||
target = control_pred
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
|
||||
@@ -121,6 +121,9 @@ class TrainConfig:
|
||||
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
|
||||
self.start_step = kwargs.get('start_step', None)
|
||||
self.free_u = kwargs.get('free_u', False)
|
||||
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
|
||||
self.match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
Reference in New Issue
Block a user