Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added a way to use a T2I adapter to guide slider training for more consistent images
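Based only on the new SliderConfig keys introduced in this commit (use_adapter, adapter_img_dir, high_ram), a slider job might turn on depth guidance roughly as in the sketch below. This is a hedged illustration: the surrounding config layout, any target fields beyond positive/negative/shuffle, and all paths are assumptions, not part of the diff.

# Hypothetical usage sketch; only use_adapter, adapter_img_dir, high_ram,
# resolutions and the positive/negative/shuffle target fields appear in this
# commit, everything else is a placeholder.
from toolkit.config_modules import SliderConfig

slider_config = SliderConfig(
    use_adapter='depth',                    # load the TencentARC depth T2I adapter
    adapter_img_dir='/path/to/depth_maps',  # depth images named like the training images
    high_ram=False,                         # False = chunk latents to save memory
    resolutions=[[512, 512]],
    targets=[
        {
            'positive': 'smiling, happy',
            'negative': 'frowning, sad',
            'shuffle': True,                # expands into capped prompt permutations
        },
    ],
)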
@@ -1,8 +1,17 @@
import os
import random
from collections import OrderedDict
from typing import Union

from PIL import Image
from diffusers import T2IAdapter
from torchvision.transforms import transforms
from tqdm import tqdm

from toolkit.basic import value_map
from toolkit.config_modules import SliderConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
@@ -21,6 +30,10 @@ def flush():
    gc.collect()


adapter_transforms = transforms.Compose([
    transforms.ToTensor(),
])


class TrainSliderProcess(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
@@ -42,6 +55,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
        # trim targets
        self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]

        # get presets
        self.eval_slider_device_state = get_train_sd_device_state_preset(
            self.device_torch,
            train_unet=False,
            train_text_encoder=False,
            cached_latents=self.is_latents_cached,
            train_lora=False,
            train_adapter=False,
            train_embedding=False,
        )

        self.train_slider_device_state = get_train_sd_device_state_preset(
            self.device_torch,
            train_unet=self.train_config.train_unet,
            train_text_encoder=False,
            cached_latents=self.is_latents_cached,
            train_lora=True,
            train_adapter=False,
            train_embedding=False,
        )

    def before_model_load(self):
        pass

@@ -66,6 +100,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
        # trim list to our max steps

        cache = PromptEmbedsCache()
        print(f"Building prompt cache")

        # get encoded latents for our prompts
        with torch.no_grad():
@@ -175,30 +210,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
            self.sd.vae.to(self.device_torch)
        # end hook_before_train_loop

    def before_dataset_load(self):
        if self.slider_config.use_adapter == 'depth':
            print(f"Loading T2I Adapter for depth")
            # called before LoRA network is loaded but after model is loaded
            # attach the adapter here so it is there before we load the network
            adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
            if self.sd.is_xl:
                adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'

            # don't name this adapter since we are not training it
            self.t2i_adapter = T2IAdapter.from_pretrained(
                adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
            ).to(self.device_torch)
            self.t2i_adapter.eval()
            self.t2i_adapter.requires_grad_(False)
            flush()

    @torch.no_grad()
    def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
        img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
        adapter_folder_path = self.slider_config.adapter_img_dir
        adapter_images = []
        # loop through images
        for file_item in batch.file_items:
            img_path = file_item.path
            file_name_no_ext = os.path.basename(img_path).split('.')[0]
            # find the image
            for ext in img_ext_list:
                if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
                    adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
                    break
        width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
        adapter_tensors = []
        # load images with torch transforms
        for idx, adapter_image in enumerate(adapter_images):
            # we need to centrally crop the largest dimension of the image to match the batch shape after scaling
            # to the smallest dimension
            img: Image.Image = Image.open(adapter_image)
            if img.width > img.height:
                # scale down so height is the same as batch
                new_height = height
                new_width = int(img.width * (height / img.height))
            else:
                new_width = width
                new_height = int(img.height * (width / img.width))

            img = img.resize((new_width, new_height))
            crop_fn = transforms.CenterCrop((height, width))
            # crop the center to match batch
            img = crop_fn(img)
            img = adapter_transforms(img)
            adapter_tensors.append(img)

        # stack them
        adapter_tensors = torch.stack(adapter_tensors).to(
            self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
        )
        return adapter_tensors

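get_adapter_images above pairs every image in the batch with a control image of the same base filename in slider_config.adapter_img_dir, then resizes and center-crops it to the batch crop size. The depth maps themselves are not produced by this commit; a minimal sketch of how they could be pre-generated with the Hugging Face transformers depth-estimation pipeline (the helper name and directory layout are assumptions):

# Hypothetical helper: write one depth map per training image so that
# get_adapter_images can find a matching base filename in adapter_img_dir.
import os
from PIL import Image
from transformers import pipeline

def build_depth_maps(dataset_dir: str, adapter_img_dir: str):
    depth_estimator = pipeline("depth-estimation")  # DPT/MiDaS-style model
    os.makedirs(adapter_img_dir, exist_ok=True)
    for name in os.listdir(dataset_dir):
        if not name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
            continue
        image = Image.open(os.path.join(dataset_dir, name)).convert("RGB")
        depth: Image.Image = depth_estimator(image)["depth"]
        base, _ = os.path.splitext(name)
        depth.save(os.path.join(adapter_img_dir, base + ".png"))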
    def hook_train_loop(self, batch):
        dtype = get_torch_dtype(self.train_config.dtype)
        # set to eval mode
        self.sd.set_device_state(self.eval_slider_device_state)
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)

            # get a random pair
            prompt_pair: EncodedPromptPair = self.prompt_pairs[
                torch.randint(0, len(self.prompt_pairs), (1,)).item()
            ]
            # move to device and dtype
            prompt_pair.to(self.device_torch, dtype=dtype)

            # get a random pair
            prompt_pair: EncodedPromptPair = self.prompt_pairs[
                torch.randint(0, len(self.prompt_pairs), (1,)).item()
            ]
            # move to device and dtype
            prompt_pair.to(self.device_torch, dtype=dtype)

            # get a random resolution
            height, width = self.slider_config.resolutions[
                torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
            ]
            if self.train_config.gradient_checkpointing:
                # may get disabled elsewhere
                self.sd.unet.enable_gradient_checkpointing()
            # get a random resolution
            height, width = self.slider_config.resolutions[
                torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
            ]
            if self.train_config.gradient_checkpointing:
                # may get disabled elsewhere
                self.sd.unet.enable_gradient_checkpointing()

        noise_scheduler = self.sd.noise_scheduler
        optimizer = self.optimizer
        lr_scheduler = self.lr_scheduler

        loss_function = torch.nn.MSELoss()

        pred_kwargs = {}

        def get_noise_pred(neg, pos, gs, cts, dn):
            return self.sd.predict_noise(
                latents=dn,
@@ -209,9 +309,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
                ),
                timestep=cts,
                guidance_scale=gs,
                **pred_kwargs
            )

        with torch.no_grad():
            adapter_images = None
            # for a complete slider, the batch size is 4 to begin with now
            true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
            from_batch = False
@@ -219,9 +321,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
                # training from a batch of images, not generating ourselves
                from_batch = True
                noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
                if self.slider_config.adapter_img_dir is not None:
                    adapter_images = self.get_adapter_images(batch)
                    adapter_strength_min = 0.9
                    adapter_strength_max = 1.0

                    denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
                    denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
                    def rand_strength(sample):
                        adapter_conditioning_scale = torch.rand(
                            (1,), device=self.device_torch, dtype=dtype
                        )

                        adapter_conditioning_scale = value_map(
                            adapter_conditioning_scale,
                            0.0,
                            1.0,
                            adapter_strength_min,
                            adapter_strength_max
                        )
                        return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale

                    down_block_additional_residuals = self.t2i_adapter(adapter_images)
                    down_block_additional_residuals = [
                        rand_strength(sample) for sample in down_block_additional_residuals
                    ]
                    pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals

                denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
                current_timestep = timesteps
            else:

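In rand_strength above, value_map (imported from toolkit.basic) rescales a uniform sample from [0, 1] into [adapter_strength_min, adapter_strength_max], so each step conditions the UNet on the adapter residuals at a slightly different strength. Its implementation is not part of this diff; a minimal linear-remap sketch of what it is assumed to do:

# Assumed behaviour of toolkit.basic.value_map: a plain linear remap (sketch
# only, not the toolkit implementation). Works elementwise on tensors too.
def value_map(x, in_min, in_max, out_min, out_max):
    return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

# e.g. value_map(0.5, 0.0, 1.0, 0.9, 1.0) == 0.95, so a uniform draw in [0, 1]
# becomes an adapter conditioning scale in [0.9, 1.0].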
@@ -229,8 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
                    self.train_config.max_denoising_steps, device=self.device_torch
                )

                self.optimizer.zero_grad()

                # get a random number of steps
                timesteps_to = torch.randint(
                    1, self.train_config.max_denoising_steps, (1,)
@@ -267,13 +390,14 @@ class TrainSliderProcess(BaseSDTrainProcess):

                noise_scheduler.set_timesteps(1000)

                # split the latents into our prompt pair chunks
                denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
                denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]

                current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
                current_timestep = noise_scheduler.timesteps[current_timestep_index]

                # split the latents into our prompt pair chunks
                denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
                denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]

            # flush() # 4.2GB to 3GB on 512x512

            # 4.20 GB RAM for 512x512
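The bookkeeping above maps the randomly chosen number of denoising steps onto the scheduler's 1000-step schedule. A quick worked example with illustrative numbers (not values from the commit):

# With max_denoising_steps = 40 and timesteps_to = 10, the partially denoised
# latents sit at index 10 * 1000 / 40 = 250 of the 1000-step schedule, so
# current_timestep = noise_scheduler.timesteps[250].
max_denoising_steps = 40
timesteps_to = 10
current_timestep_index = int(timesteps_to * 1000 / max_denoising_steps)  # 250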
@@ -286,7 +410,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
            )
            positive_latents = positive_latents.detach()
            positive_latents.requires_grad = False
            positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)

            neutral_latents = get_noise_pred(
                prompt_pair.positive_target,  # negative prompt
@@ -297,7 +420,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
            )
            neutral_latents = neutral_latents.detach()
            neutral_latents.requires_grad = False
            neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)

            unconditional_latents = get_noise_pred(
                prompt_pair.positive_target,  # negative prompt
@@ -308,13 +430,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
            )
            unconditional_latents = unconditional_latents.detach()
            unconditional_latents.requires_grad = False
            unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)

            denoised_latents = denoised_latents.detach()

        # flush() # 4.2GB to 3GB on 512x512
        self.sd.set_device_state(self.train_slider_device_state)
        # start accumulating gradients
        self.optimizer.zero_grad(set_to_none=True)

        # 4.20 GB RAM for 512x512
        anchor_loss_float = None
        if len(self.anchor_pairs) > 0:
            with torch.no_grad():
@@ -369,9 +491,23 @@ class TrainSliderProcess(BaseSDTrainProcess):
                del anchor_target_noise
                # move anchor back to cpu
                anchor.to("cpu")
            # flush()

        prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
        with torch.no_grad():
            if self.slider_config.high_ram:
                # run through in one instance
                prompt_pair_chunks = [prompt_pair.detach()]
                denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
                positive_latents_chunks = [positive_latents.detach()]
                neutral_latents_chunks = [neutral_latents.detach()]
                unconditional_latents_chunks = [unconditional_latents.detach()]
            else:
                prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
                denoised_latent_chunks = denoised_latent_chunks  # just to have it in one place
                positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
                neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
                unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)

        # flush()
        assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
        # 3.28 GB RAM for 512x512
        with self.network:

@@ -186,18 +186,23 @@ class SliderConfig:
        self.prompt_file: str = kwargs.get('prompt_file', None)
        self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
        self.use_adapter: bool = kwargs.get('use_adapter', None)  # depth
        self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
        self.high_ram = kwargs.get('high_ram', False)

        # expand targets if shuffling
        from toolkit.prompt_utils import get_slider_target_permutations
        self.targets: List[SliderTargetConfig] = []
        targets = [SliderTargetConfig(**target) for target in targets]
        # do permutations if shuffle is true
        print(f"Building slider targets")
        for target in targets:
            if target.shuffle:
                target_permutations = get_slider_target_permutations(target)
                target_permutations = get_slider_target_permutations(target, max_permutations=100)
                self.targets = self.targets + target_permutations
            else:
                self.targets.append(target)
        print(f"Built {len(self.targets)} slider targets (with permutations)")


class DatasetConfig:

@@ -105,6 +105,18 @@ class EncodedPromptPair:
        self.both_targets = self.both_targets.to(*args, **kwargs)
        return self

    def detach(self):
        self.target_class = self.target_class.detach()
        self.target_class_with_neutral = self.target_class_with_neutral.detach()
        self.positive_target = self.positive_target.detach()
        self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
        self.negative_target = self.negative_target.detach()
        self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
        self.neutral = self.neutral.detach()
        self.empty_prompt = self.empty_prompt.detach()
        self.both_targets = self.both_targets.detach()
        return self


def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
    text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
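Each attribute detached above is a PromptEmbeds object, so EncodedPromptPair.detach() simply drops the autograd history of the cached text-encoder outputs before they are reused across chunks. A minimal sketch of what PromptEmbeds.detach() is assumed to look like (the real class lives elsewhere in toolkit; pooled_embeds is an assumption for SDXL-style models):

# Sketch only: the actual PromptEmbeds class is defined elsewhere in toolkit.
class PromptEmbeds:
    def __init__(self, text_embeds, pooled_embeds=None):
        self.text_embeds = text_embeds
        self.pooled_embeds = pooled_embeds

    def detach(self):
        self.text_embeds = self.text_embeds.detach()
        if self.pooled_embeds is not None:
            self.pooled_embeds = self.pooled_embeds.detach()
        return self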
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
    return anchors


def get_permutations(s):
def get_permutations(s, max_permutations=8):
    # Split the string by comma
    phrases = [phrase.strip() for phrase in s.split(',')]

    # remove empty strings
    phrases = [phrase for phrase in phrases if len(phrase) > 0]
    # shuffle the list
    random.shuffle(phrases)

    # Get all permutations
    permutations = list(itertools.permutations(phrases))
    permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])

    # Convert the tuples back to comma separated strings
    return [', '.join(permutation) for permutation in permutations]
@@ -283,8 +297,8 @@ def get_permutations(s):

def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
    from toolkit.config_modules import SliderTargetConfig
    pos_permutations = get_permutations(target.positive)
    neg_permutations = get_permutations(target.negative)
    pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
    neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)

    permutations = []
    for pos, neg in itertools.product(pos_permutations, neg_permutations):
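The change above stops materializing every permutation and instead takes only the first max_permutations orderings via itertools.islice, which matters because the count grows factorially with the number of comma-separated phrases. A small standalone illustration:

import itertools

phrases = ["photo", "portrait", "studio lighting", "sharp focus"]
# All permutations of 4 phrases would be 4! = 24; islice stops after the first 8.
capped = list(itertools.islice(itertools.permutations(phrases), 8))
print(len(capped))  # 8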