diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index dd91b155..180be04f 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -1,8 +1,17 @@
+import os
 import random
 from collections import OrderedDict
+from typing import Union
+
+from PIL import Image
+from diffusers import T2IAdapter
+from torchvision.transforms import transforms
 from tqdm import tqdm

+from toolkit.basic import value_map
 from toolkit.config_modules import SliderConfig
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
 from toolkit.train_tools import get_torch_dtype, apply_snr_weight
 import gc
 from toolkit import train_tools
@@ -21,6 +30,10 @@ def flush():
     gc.collect()


+adapter_transforms = transforms.Compose([
+    transforms.ToTensor(),
+])
+
 class TrainSliderProcess(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
@@ -42,6 +55,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
         # trim targets
         self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]

+        # get presets
+        self.eval_slider_device_state = get_train_sd_device_state_preset(
+            self.device_torch,
+            train_unet=False,
+            train_text_encoder=False,
+            cached_latents=self.is_latents_cached,
+            train_lora=False,
+            train_adapter=False,
+            train_embedding=False,
+        )
+
+        self.train_slider_device_state = get_train_sd_device_state_preset(
+            self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=False,
+            cached_latents=self.is_latents_cached,
+            train_lora=True,
+            train_adapter=False,
+            train_embedding=False,
+        )
+
     def before_model_load(self):
         pass

@@ -66,6 +100,7 @@ class TrainSliderProcess(BaseSDTrainProcess):

         # trim list to our max steps
         cache = PromptEmbedsCache()
+        print(f"Building prompt cache")

         # get encoded latents for our prompts
         with torch.no_grad():
@@ -175,30 +210,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
             self.sd.vae.to(self.device_torch)
         # end hook_before_train_loop

+    def before_dataset_load(self):
+        if self.slider_config.use_adapter == 'depth':
+            print(f"Loading T2I Adapter for depth")
+            # called before the LoRA network is loaded but after the model is loaded
+            # attach the adapter here so it is in place before we load the network
+            adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
+            if self.sd.is_xl:
+                adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
+
+            # don't name this attribute "adapter", since we are not training it
+            self.t2i_adapter = T2IAdapter.from_pretrained(
+                adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
+            ).to(self.device_torch)
+            self.t2i_adapter.eval()
+            self.t2i_adapter.requires_grad_(False)
+            flush()
+
+    @torch.no_grad()
+    def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
+
+        img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+        adapter_folder_path = self.slider_config.adapter_img_dir
+        adapter_images = []
+        # loop through images
+        for file_item in batch.file_items:
+            img_path = file_item.path
+            file_name_no_ext = os.path.basename(img_path).split('.')[0]
+            # find the image
+            for ext in img_ext_list:
+                if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
+                    adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
+                    break
+        width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
+        adapter_tensors = []
+        # load images with torch transforms
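+        # note: adapter_images above follows batch.file_items order; a training image
+        # with no matching control image in adapter_img_dir is silently skipped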
+        for idx, adapter_image in enumerate(adapter_images):
+            # center crop the largest dimension of the image to match the batch shape,
+            # after scaling the smallest dimension to match
+            img: Image.Image = Image.open(adapter_image)
+            if img.width > img.height:
+                # scale down so height is the same as batch
+                new_height = height
+                new_width = int(img.width * (height / img.height))
+            else:
+                new_width = width
+                new_height = int(img.height * (width / img.width))
+
+            img = img.resize((new_width, new_height))
+            crop_fn = transforms.CenterCrop((height, width))
+            # crop the center to match batch
+            img = crop_fn(img)
+            img = adapter_transforms(img)
+            adapter_tensors.append(img)
+
+        # stack them
+        adapter_tensors = torch.stack(adapter_tensors).to(
+            self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
+        )
+        return adapter_tensors
+
     def hook_train_loop(self, batch):
-        dtype = get_torch_dtype(self.train_config.dtype)
+        # set to eval mode
+        self.sd.set_device_state(self.eval_slider_device_state)
+        with torch.no_grad():
+            dtype = get_torch_dtype(self.train_config.dtype)
+            # get a random pair
+            prompt_pair: EncodedPromptPair = self.prompt_pairs[
+                torch.randint(0, len(self.prompt_pairs), (1,)).item()
+            ]
+            # move to device and dtype
+            prompt_pair.to(self.device_torch, dtype=dtype)

-        # get a random pair
-        prompt_pair: EncodedPromptPair = self.prompt_pairs[
-            torch.randint(0, len(self.prompt_pairs), (1,)).item()
-        ]
-        # move to device and dtype
-        prompt_pair.to(self.device_torch, dtype=dtype)
-
-        # get a random resolution
-        height, width = self.slider_config.resolutions[
-            torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
-        ]
-        if self.train_config.gradient_checkpointing:
-            # may get disabled elsewhere
-            self.sd.unet.enable_gradient_checkpointing()
+            # get a random resolution
+            height, width = self.slider_config.resolutions[
+                torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
+            ]
+            if self.train_config.gradient_checkpointing:
+                # may get disabled elsewhere
+                self.sd.unet.enable_gradient_checkpointing()

         noise_scheduler = self.sd.noise_scheduler
         optimizer = self.optimizer
         lr_scheduler = self.lr_scheduler

+        loss_function = torch.nn.MSELoss()
+
+        pred_kwargs = {}
+
         def get_noise_pred(neg, pos, gs, cts, dn):
             return self.sd.predict_noise(
                 latents=dn,
@@ -209,9 +309,11 @@
                 ),
                 timestep=cts,
                 guidance_scale=gs,
+                **pred_kwargs
             )

         with torch.no_grad():
+            adapter_images = None
             # for a complete slider, the batch size is 4 to begin with now
             true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
             from_batch = False
@@ -219,9 +321,32 @@
                 # traing from a batch of images, not generating ourselves
                 from_batch = True
                 noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+                if self.slider_config.adapter_img_dir is not None:
+                    adapter_images = self.get_adapter_images(batch)
+                    adapter_strength_min = 0.9
+                    adapter_strength_max = 1.0

-                denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
-                denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
+                    def rand_strength(sample):
+                        adapter_conditioning_scale = torch.rand(
+                            (1,), device=self.device_torch, dtype=dtype
+                        )
+
+                        adapter_conditioning_scale = value_map(
+                            adapter_conditioning_scale,
+                            0.0,
+                            1.0,
+                            adapter_strength_min,
+                            adapter_strength_max
+                        )
+                        return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
+
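+                    # each down-block residual from the frozen adapter gets its own random
+                    # conditioning strength in [adapter_strength_min, adapter_strength_max]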
+                    down_block_additional_residuals = self.t2i_adapter(adapter_images)
+                    down_block_additional_residuals = [
+                        rand_strength(sample) for sample in down_block_additional_residuals
+                    ]
+                    pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+
+                denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
                 current_timestep = timesteps
             else:
@@ -229,8 +354,6 @@
                     self.train_config.max_denoising_steps, device=self.device_torch
                 )

-                self.optimizer.zero_grad()
-
                 # ger a random number of steps
                 timesteps_to = torch.randint(
                     1, self.train_config.max_denoising_steps, (1,)
                 )
@@ -267,13 +390,14 @@
                 noise_scheduler.set_timesteps(1000)

-                # split the latents into out prompt pair chunks
-                denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
-                denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]

                 current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
                 current_timestep = noise_scheduler.timesteps[current_timestep_index]

+            # split the latents into our prompt pair chunks
+            denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
+            denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
+
             # flush()
         # 4.2GB to 3GB on 512x512

         # 4.20 GB RAM for 512x512
@@ -286,7 +410,6 @@
             )
             positive_latents = positive_latents.detach()
             positive_latents.requires_grad = False
-            positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)

             neutral_latents = get_noise_pred(
                 prompt_pair.positive_target,  # negative prompt
@@ -297,7 +420,6 @@
             )
             neutral_latents = neutral_latents.detach()
             neutral_latents.requires_grad = False
-            neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)

             unconditional_latents = get_noise_pred(
                 prompt_pair.positive_target,  # negative prompt
@@ -308,13 +430,13 @@
             )
             unconditional_latents = unconditional_latents.detach()
             unconditional_latents.requires_grad = False
-            unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)

             denoised_latents = denoised_latents.detach()

-        # flush()  # 4.2GB to 3GB on 512x512
+        self.sd.set_device_state(self.train_slider_device_state)
+        # start accumulating gradients
+        self.optimizer.zero_grad(set_to_none=True)

-        # 4.20 GB RAM for 512x512
         anchor_loss_float = None
         if len(self.anchor_pairs) > 0:
             with torch.no_grad():
@@ -369,9 +491,23 @@
                     del anchor_target_noise
                     # move anchor back to cpu
                     anchor.to("cpu")
-            # flush()

-        prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
+        with torch.no_grad():
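+            # high_ram keeps everything as a single chunk; otherwise the pair is split
+            # into prompt_chunk_size pieces to keep peak VRAM lower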
+            if self.slider_config.high_ram:
+                # run through in one instance
+                prompt_pair_chunks = [prompt_pair.detach()]
+                denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
+                positive_latents_chunks = [positive_latents.detach()]
+                neutral_latents_chunks = [neutral_latents.detach()]
+                unconditional_latents_chunks = [unconditional_latents.detach()]
+            else:
+                prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
+                denoised_latent_chunks = denoised_latent_chunks  # just to have it in one place
+                positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
+                neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
+                unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
+
+        # flush()
         assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
         # 3.28 GB RAM for 512x512
         with self.network:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a7adbe3e..7fda9382 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -186,18 +186,23 @@ class SliderConfig:
         self.prompt_file: str = kwargs.get('prompt_file', None)
         self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
         self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
+        self.use_adapter: str = kwargs.get('use_adapter', None)  # depth
+        self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
+        self.high_ram = kwargs.get('high_ram', False)

         # expand targets if shuffling
         from toolkit.prompt_utils import get_slider_target_permutations
         self.targets: List[SliderTargetConfig] = []
         targets = [SliderTargetConfig(**target) for target in targets]
         # do permutations if shuffle is true
+        print(f"Building slider targets")
         for target in targets:
             if target.shuffle:
-                target_permutations = get_slider_target_permutations(target)
+                target_permutations = get_slider_target_permutations(target, max_permutations=100)
                 self.targets = self.targets + target_permutations
             else:
                 self.targets.append(target)
+        print(f"Built {len(self.targets)} slider targets (with permutations)")


 class DatasetConfig:
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
index c300e6eb..e56cb793 100644
--- a/toolkit/prompt_utils.py
+++ b/toolkit/prompt_utils.py
@@ -105,6 +105,18 @@ class EncodedPromptPair:
         self.both_targets = self.both_targets.to(*args, **kwargs)
         return self

+    def detach(self):
+        self.target_class = self.target_class.detach()
+        self.target_class_with_neutral = self.target_class_with_neutral.detach()
+        self.positive_target = self.positive_target.detach()
+        self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
+        self.negative_target = self.negative_target.detach()
+        self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
+        self.neutral = self.neutral.detach()
+        self.empty_prompt = self.empty_prompt.detach()
+        self.both_targets = self.both_targets.detach()
+        return self
+

 def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
     text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
     return anchors


-def get_permutations(s):
+def get_permutations(s, max_permutations=8):
     # Split the string by comma
     phrases = [phrase.strip() for phrase in s.split(',')]

     # remove empty strings
     phrases = [phrase for phrase in phrases if len(phrase) > 0]

+    # shuffle the list
+    random.shuffle(phrases)
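+    # shuffling first makes the truncated permutation list below a random sample of
+    # orderings rather than always the first max_permutations of them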
     # Get all permutations
-    permutations = list(itertools.permutations(phrases))
+    permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])

     # Convert the tuples back to comma separated strings
     return [', '.join(permutation) for permutation in permutations]

@@ -283,8 +297,8 @@

 def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
     from toolkit.config_modules import SliderTargetConfig
-    pos_permutations = get_permutations(target.positive)
-    neg_permutations = get_permutations(target.negative)
+    pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
+    neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)

     permutations = []
     for pos, neg in itertools.product(pos_permutations, neg_permutations):
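Not part of the patch: a minimal sketch of how the new SliderConfig keys above might be
supplied, assuming the usual kwargs path into SliderConfig (key names and defaults come
from the diff; the values and path are illustrative only):

    # illustrative values only; key names and defaults taken from the diff above
    slider_kwargs = dict(
        use_adapter='depth',                        # routes before_dataset_load() to the frozen T2I depth adapter
        adapter_img_dir='/path/to/control_images',  # control images matched to training images by filename
        high_ram=False,                             # True processes the full prompt pair as a single chunk
    )
    # SliderConfig(**slider_kwargs) then exposes slider_config.use_adapter, .adapter_img_dir, .high_ram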