Added a way to use a T2I adapter to guide slider training for more consistent images

Jaret Burkett
2023-09-28 14:08:56 -06:00
parent c5d49ba661
commit 8509da60cb
3 changed files with 189 additions and 34 deletions
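
The new behavior is driven by three SliderConfig options added in this commit: use_adapter, adapter_img_dir, and high_ram. A minimal sketch of how they might be passed, assuming the toolkit's usual kwargs-style config (all values here are hypothetical):

from toolkit.config_modules import SliderConfig

slider_config = SliderConfig(
    use_adapter='depth',                    # 'depth' is the only adapter mode handled below
    adapter_img_dir='/path/to/depth_maps',  # hypothetical dir; control images are matched to training images by file name
    high_ram=True,                          # concatenate all prompt-pair chunks into one pass
    resolutions=[[512, 512]],               # hypothetical; a resolution is sampled per step
    targets=[],                             # hypothetical; normally a list of slider target dicts
)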

View File

@@ -1,8 +1,17 @@
import os
import random
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torchvision.transforms import transforms
from tqdm import tqdm
from toolkit.basic import value_map
from toolkit.config_modules import SliderConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
@@ -21,6 +30,10 @@ def flush():
gc.collect()
adapter_transforms = transforms.Compose([
transforms.ToTensor(),
])
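# ToTensor converts each PIL control image to a CxHxW float tensor in [0, 1];
# note no Normalize step is applied before the tensors reach the adapter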
class TrainSliderProcess(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -42,6 +55,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim targets
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
# get presets
self.eval_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=False,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=False,
train_adapter=False,
train_embedding=False,
)
self.train_slider_device_state = get_train_sd_device_state_preset(
self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=False,
cached_latents=self.is_latents_cached,
train_lora=True,
train_adapter=False,
train_embedding=False,
)
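# the eval preset keeps every module frozen for the no-grad prediction passes in
# hook_train_loop below; the train preset re-enables the LoRA network (and the
# unet, if configured) before gradients are accumulated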
def before_model_load(self):
pass
@@ -66,6 +100,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# trim list to our max steps
cache = PromptEmbedsCache()
print(f"Building prompt cache")
# get encoded latents for our prompts
with torch.no_grad():
@@ -175,30 +210,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.sd.vae.to(self.device_torch)
# end hook_before_train_loop
def before_dataset_load(self):
if self.slider_config.use_adapter == 'depth':
print(f"Loading T2I Adapter for depth")
# called before LoRA network is loaded but after model is loaded
# attach the adapter here so it is there before we load the network
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
if self.sd.is_xl:
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
# don't name this adapter since we are not training it
self.t2i_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), variant="fp16"
).to(self.device_torch)
self.t2i_adapter.eval()
self.t2i_adapter.requires_grad_(False)
flush()
@torch.no_grad()
def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
adapter_folder_path = self.slider_config.adapter_img_dir
adapter_images = []
# loop through images
for file_item in batch.file_items:
img_path = file_item.path
file_name_no_ext = os.path.basename(img_path).split('.')[0]
# find the image
for ext in img_ext_list:
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
break
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
# scale the image so its smaller side matches the batch, then center-crop
# the larger side to match the batch shape
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)
# stack them
adapter_tensors = torch.stack(adapter_tensors).to(
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
)
return adapter_tensors
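# worked example: a 768x1024 control image for a 512x512 batch takes the else
# branch (width < height), so new_width = 512 and new_height = int(1024 * 512 / 768) = 682;
# CenterCrop((512, 512)) then trims the extra 170 rows evenly from top and bottom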
def hook_train_loop(self, batch):
# set to eval mode
self.sd.set_device_state(self.eval_slider_device_state)
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
]
# move to device and dtype
prompt_pair.to(self.device_torch, dtype=dtype)
# get a random resolution
height, width = self.slider_config.resolutions[
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
if self.train_config.gradient_checkpointing:
# may get disabled elsewhere
self.sd.unet.enable_gradient_checkpointing()
noise_scheduler = self.sd.noise_scheduler
optimizer = self.optimizer
lr_scheduler = self.lr_scheduler
loss_function = torch.nn.MSELoss()
pred_kwargs = {}
def get_noise_pred(neg, pos, gs, cts, dn):
return self.sd.predict_noise(
latents=dn,
@@ -209,9 +309,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
timestep=cts,
guidance_scale=gs,
**pred_kwargs
)
with torch.no_grad():
adapter_images = None
# for a complete slider, the batch size is 4 to begin with now
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
from_batch = False
@@ -219,9 +321,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
# training from a batch of images, not generating ourselves
from_batch = True
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.slider_config.adapter_img_dir is not None:
adapter_images = self.get_adapter_images(batch)
adapter_strength_min = 0.9
adapter_strength_max = 1.0
def rand_strength(sample):
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
down_block_additional_residuals = self.t2i_adapter(adapter_images)
down_block_additional_residuals = [
rand_strength(sample) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
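# value_map is assumed to be a linear range remap here, so each residual tensor is
# scaled by its own strength drawn uniformly from [adapter_strength_min, adapter_strength_max],
# i.e. [0.9, 1.0], before being handed to predict_noise through pred_kwargs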
current_timestep = timesteps
else:
@@ -229,8 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.train_config.max_denoising_steps, device=self.device_torch
)
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
@@ -267,13 +390,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
noise_scheduler.set_timesteps(1000)
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# split the latents into our prompt pair chunks
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
# flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
@@ -286,7 +410,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
positive_latents = positive_latents.detach()
positive_latents.requires_grad = False
neutral_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -297,7 +420,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
neutral_latents = neutral_latents.detach()
neutral_latents.requires_grad = False
unconditional_latents = get_noise_pred(
prompt_pair.positive_target, # negative prompt
@@ -308,13 +430,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
unconditional_latents = unconditional_latents.detach()
unconditional_latents.requires_grad = False
denoised_latents = denoised_latents.detach()
self.sd.set_device_state(self.train_slider_device_state)
# start accumulating gradients
self.optimizer.zero_grad(set_to_none=True)
# 4.20 GB RAM for 512x512
anchor_loss_float = None
if len(self.anchor_pairs) > 0:
with torch.no_grad():
@@ -369,9 +491,23 @@ class TrainSliderProcess(BaseSDTrainProcess):
del anchor_target_noise
# move anchor back to cpu
anchor.to("cpu")
with torch.no_grad():
if self.slider_config.high_ram:
# run through in one instance
prompt_pair_chunks = [prompt_pair.detach()]
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
positive_latents_chunks = [positive_latents.detach()]
neutral_latents_chunks = [neutral_latents.detach()]
unconditional_latents_chunks = [unconditional_latents.detach()]
else:
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
# flush()
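# the high_ram path trades memory for speed: all chunks are concatenated into a single
# batch so the loss loop below runs once, while the default path keeps prompt_chunk_size
# chunks and iterates over them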
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
# 3.28 GB RAM for 512x512
with self.network:

View File

@@ -186,18 +186,23 @@ class SliderConfig:
self.prompt_file: str = kwargs.get('prompt_file', None)
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
self.use_adapter: str = kwargs.get('use_adapter', None) # 'depth'
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
self.high_ram = kwargs.get('high_ram', False)
# expand targets if shuffling
from toolkit.prompt_utils import get_slider_target_permutations
self.targets: List[SliderTargetConfig] = []
targets = [SliderTargetConfig(**target) for target in targets]
# do permutations if shuffle is true
print(f"Building slider targets")
for target in targets:
if target.shuffle:
target_permutations = get_slider_target_permutations(target, max_permutations=100)
self.targets = self.targets + target_permutations
else:
self.targets.append(target)
print(f"Built {len(self.targets)} slider targets (with permutations)")
class DatasetConfig:

View File

@@ -105,6 +105,18 @@ class EncodedPromptPair:
self.both_targets = self.both_targets.to(*args, **kwargs)
return self
def detach(self):
self.target_class = self.target_class.detach()
self.target_class_with_neutral = self.target_class_with_neutral.detach()
self.positive_target = self.positive_target.detach()
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
self.negative_target = self.negative_target.detach()
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
self.neutral = self.neutral.detach()
self.empty_prompt = self.empty_prompt.detach()
self.both_targets = self.both_targets.detach()
return self
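# like to() above, detach() returns self so it can be chained inline, e.g.
# split_prompt_pairs(prompt_pair.detach(), ...) in TrainSliderProcess.hook_train_loop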
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
return anchors
def get_permutations(s, max_permutations=8):
# Split the string by comma
phrases = [phrase.strip() for phrase in s.split(',')]
# remove empty strings
phrases = [phrase for phrase in phrases if len(phrase) > 0]
# shuffle the list
random.shuffle(phrases)
# Get up to max_permutations permutations
permutations = list(itertools.islice(itertools.permutations(phrases), max_permutations))
# Convert the tuples back to comma separated strings
return [', '.join(permutation) for permutation in permutations]
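# example: get_permutations('red, blue', max_permutations=8) returns
# ['red, blue', 'blue, red'] (list order depends on the shuffle above); the islice
# cap matters for longer lists, whose permutation count grows factorially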
@@ -283,8 +297,8 @@ def get_permutations(s):
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
permutations = []
for pos, neg in itertools.product(pos_permutations, neg_permutations):