Added T2I adapter guided slider training for more consistent images

This commit is contained in:
Jaret Burkett
2023-09-28 14:08:56 -06:00
parent c5d49ba661
commit 8509da60cb
3 changed files with 189 additions and 34 deletions
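The feature is driven by three new SliderConfig options (see the config diff below). As a minimal sketch, assuming the usual dict-style kwargs that SliderConfig consumes, a process config might pass something like the following (values here are illustrative):

slider_config_kwargs = {
    'use_adapter': 'depth',  # load a TencentARC depth T2I adapter (SD1.5 or SDXL variant)
    'adapter_img_dir': '/path/to/depth_maps',  # hypothetical path; control images are matched to batch images by basename
    'high_ram': False,  # True runs all prompt chunks through the loss in a single pass
}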

View File

@@ -1,8 +1,17 @@
import os
import random
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torchvision.transforms import transforms
from tqdm import tqdm
from toolkit.basic import value_map
from toolkit.config_modules import SliderConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
@@ -21,6 +30,10 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -42,6 +55,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim targets
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
# get presets
self.eval_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=False,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=False,
train_adapter=False,
train_embedding=False,
)
self.train_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=True,
train_adapter=False,
train_embedding=False,
)
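# the eval preset keeps everything frozen while generating the reference
# noise predictions; the train preset re-enables gradients only for the
# LoRA (and the unet when train_unet is set)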
def before_model_load(self):
pass
@@ -66,6 +100,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim list to our max steps
cache = PromptEmbedsCache()
print(f"Building prompt cache")
# get encoded latents for our prompts
with torch.no_grad():
@@ -175,30 +210,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.sd.vae.to(self.device_torch)
# end hook_before_train_loop
def before_dataset_load(self):
if self.slider_config.use_adapter == 'depth':
print(f"Loading T2I Adapter for depth")
# called before LoRA network is loaded but after model is loaded
# attach the adapter here so it is there before we load the network
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
if self.sd.is_xl:
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
# don't name this adapter since we are not training it
self.t2i_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16"
).to(self.device_torch)
self.t2i_adapter.eval()
self.t2i_adapter.requires_grad_(False)
flush()
@torch.no_grad()
def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.slider_config.adapter_img_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
# scale the image so its smaller dimension matches the batch, then
# center-crop the larger dimension down to the batch shape
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
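# control images are matched by basename: for a batch item such as img_001.jpg
# (an illustrative name), this looks for img_001.{jpg,jpeg,png,webp} inside
# adapter_img_dir, then resizes and center-crops each match to the batch shape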
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
# set to eval mode
self.sd.set_device_state(self.eval_slider_device_state)
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
]
# move to device and dtype
prompt_pair.to(self.device_torch, dtype=dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
]
# move to device and dtype
prompt_pair.to(self.device_torch, dtype=dtype)
# get a random resolution
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
# get a random resolution
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
return self.sd.predict_noise(
latents=dn,
@@ -209,9 +309,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**pred_kwargs
)
with torch.no_grad():
adapter_images = None
# for a complete slider, the batch size now starts at 4
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
from_batch = False
@@ -219,9 +321,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
# training from a batch of images, not generating ourselves
from_batch = True
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.slider_config.adapter_img_dir is not None:
adapter_images = self.get_adapter_images(batch)
adapter_strength_min = 0.9
adapter_strength_max = 1.0
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
def rand_strength(sample):
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
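# assuming value_map is a linear remap, each residual below is scaled by a
# random strength in [adapter_strength_min, adapter_strength_max], e.g. a
# uniform sample of 0.5 maps to 0.95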
down_block_additional_residuals = self.t2i_adapter(adapter_images)
down_block_additional_residuals = [
rand_strength(sample) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
current_timestep = timesteps
else:
@@ -229,8 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.train_config.max_denoising_steps, device=self.device_torch
)
self.optimizer.zero_grad()
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
@@ -267,13 +390,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
noise_scheduler.set_timesteps(1000)
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
# flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
@@ -286,7 +410,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
positive_latents = positive_latents.detach()
positive_latents.requires_grad = False
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
neutral_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -297,7 +420,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
neutral_latents = neutral_latents.detach()
neutral_latents.requires_grad = False
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
unconditional_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -308,13 +430,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
unconditional_latents = unconditional_latents.detach()
unconditional_latents.requires_grad = False
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
denoised_latents = denoised_latents.detach()
# flush() # 4.2GB to 3GB on 512x512
self.sd.set_device_state(self.train_slider_device_state)
# start accumulating gradients
self.optimizer.zero_grad(set_to_none=True)
# 4.20 GB RAM for 512x512
anchor_loss_float = None
if len(self.anchor_pairs) > 0:
with torch.no_grad():
@@ -369,9 +491,23 @@ class TrainSliderProcess(BaseSDTrainProcess):
del anchor_target_noise
# move anchor back to cpu
anchor.to("cpu")
# flush()
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
with torch.no_grad():
if self.slider_config.high_ram:
# run through in one instance
prompt_pair_chunks = [prompt_pair.detach()]
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
positive_latents_chunks = [positive_latents.detach()]
neutral_latents_chunks = [neutral_latents.detach()]
unconditional_latents_chunks = [unconditional_latents.detach()]
else:
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
# flush()
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
# 3.28 GB RAM for 512x512
with self.network:

View File

@@ -186,18 +186,23 @@ class SliderConfig:
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: str = kwargs.get('use_adapter', None)  # e.g. 'depth'
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.high_ram = kwargs.get('high_ram', False)
# expand targets if shuffling
from toolkit.prompt_utils import get_slider_target_permutations
self.targets: List[SliderTargetConfig] = []
targets = [SliderTargetConfig(**target) for target in targets]
# do permutations if shuffle is true
print(f"Building slider targets")
for target in targets:
if target.shuffle:
target_permutations = get_slider_target_permutations(target)
target_permutations = get_slider_target_permutations(target, max_permutations=100)
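# max_permutations=100 caps the expansion; a target with n comma-separated
# phrases otherwise yields up to n! orderings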
self.targets = self.targets + target_permutations
else:
self.targets.append(target)
print(f"Built {len(self.targets)} slider targets (with permutations)")
class DatasetConfig:

View File

@@ -105,6 +105,18 @@ class EncodedPromptPair:
self.both_targets = self.both_targets.to(*args, **kwargs)
return self
def detach(self):
self.target_class = self.target_class.detach()
self.target_class_with_neutral = self.target_class_with_neutral.detach()
self.positive_target = self.positive_target.detach()
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
self.negative_target = self.negative_target.detach()
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
self.neutral = self.neutral.detach()
self.empty_prompt = self.empty_prompt.detach()
self.both_targets = self.both_targets.detach()
return self
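# detaching every cached embed drops its autograd graph, so the pair can be
# split into chunks and reused during loss computation without holding
# gradient state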
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
return anchors
def get_permutations(s):
def get_permutations(s, max_permutations=8):
# Split the string by comma
phrases = [phrase.strip() for phrase in s.split(',')]
# remove empty strings
phrases = [phrase for phrase in phrases if len(phrase) > 0]
# shuffle the list
random.shuffle(phrases)
# Take at most max_permutations permutations
permutations = list(itertools.permutations(phrases))
permutations = list(itertools.islice(itertools.permutations(phrases), max_permutations))
# Convert the tuples back to comma separated strings
return [', '.join(permutation) for permutation in permutations]
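# e.g. get_permutations('red hair, long hair', max_permutations=8) returns at
# most 8 comma-joined orderings such as 'long hair, red hair'; the shuffle
# above makes the kept subset random rather than always the first orderings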
@@ -283,8 +297,8 @@ def get_permutations(s):
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
permutations = []
for pos, neg in itertools.product(pos_permutations, neg_permutations):