From c446f768ea0c7667e796b925e37a97ca21861e53 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 27 Aug 2023 17:48:02 -0600
Subject: [PATCH] Huge memory optimizations, many bug fixes

---
 extensions_built_in/sd_trainer/SDTrainer.py |  6 +-
 jobs/BaseJob.py                             |  7 +-
 jobs/ExtensionJob.py                        |  3 +-
 jobs/GenerateJob.py                         |  1 -
 jobs/TrainJob.py                            |  1 -
 jobs/process/BaseExtensionProcess.py        |  7 +-
 jobs/process/BaseExtractProcess.py          |  9 +-
 jobs/process/BaseMergeProcess.py            |  4 +-
 jobs/process/BaseProcess.py                 |  2 +-
 jobs/process/BaseSDTrainProcess.py          |  8 +-
 jobs/process/BaseTrainProcess.py            | 11 +--
 jobs/process/GenerateProcess.py             |  2 +-
 jobs/process/ModRescaleLoraProcess.py       |  3 +
 toolkit/lora_special.py                     | 94 +++++++++++----------
 toolkit/stable_diffusion_model.py           |  6 +-
 15 files changed, 86 insertions(+), 78 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 3e25e0c..4edbf26 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -14,12 +14,9 @@ def flush():
 
 
 class SDTrainer(BaseSDTrainProcess):
-    sd: StableDiffusion
-    data_loader: DataLoader = None
 
     def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
         super().__init__(process_id, job, config, **kwargs)
-        pass
 
     def before_model_load(self):
         pass
@@ -40,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
 
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         self.optimizer.zero_grad()
+        flush()
 
         # text encoding
         grad_on_text_encoder = False
@@ -71,7 +69,7 @@ class SDTrainer(BaseSDTrainProcess):
                 timestep=timesteps,
                 guidance_scale=1.0,
             )
-
+            # 9.18 gb
             noise = noise.to(self.device_torch, dtype=dtype)
 
             if self.sd.prediction_type == 'v_prediction':
diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py
index 6027e71..8efd009 100644
--- a/jobs/BaseJob.py
+++ b/jobs/BaseJob.py
@@ -6,19 +6,16 @@ from jobs.process import BaseProcess
 
 
 class BaseJob:
-    config: OrderedDict
-    job: str
-    name: str
-    meta: OrderedDict
-    process: List[BaseProcess]
 
     def __init__(self, config: OrderedDict):
         if not config:
             raise ValueError('config is required')
+        self.process: List[BaseProcess]
         self.config = config['config']
         self.raw_config = config
         self.job = config['job']
+        self.torch_profiler = self.get_conf('torch_profiler', False)
         self.name = self.get_conf('name', required=True)
         if 'meta' in config:
             self.meta = config['meta']
diff --git a/jobs/ExtensionJob.py b/jobs/ExtensionJob.py
index e1ddc96..def4f85 100644
--- a/jobs/ExtensionJob.py
+++ b/jobs/ExtensionJob.py
@@ -1,7 +1,8 @@
+import os
 from collections import OrderedDict
 from jobs import BaseJob
 from toolkit.extension import get_all_extensions_process_dict
-
+from toolkit.paths import CONFIG_ROOT
 
 
 class ExtensionJob(BaseJob):
diff --git a/jobs/GenerateJob.py b/jobs/GenerateJob.py
index 5bab114..ab61701 100644
--- a/jobs/GenerateJob.py
+++ b/jobs/GenerateJob.py
@@ -14,7 +14,6 @@ process_dict = {
 
 
 class GenerateJob(BaseJob):
-    process: List[GenerateProcess]
 
     def __init__(self, config: OrderedDict):
         super().__init__(config)
diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
index 31ed51e..dda64e2 100644
--- a/jobs/TrainJob.py
+++ b/jobs/TrainJob.py
@@ -26,7 +26,6 @@ process_dict = {
 
 
 class TrainJob(BaseJob):
-    process: List[BaseExtractProcess]
 
     def __init__(self, config: OrderedDict):
         super().__init__(config)
diff --git a/jobs/process/BaseExtensionProcess.py b/jobs/process/BaseExtensionProcess.py
index d618563..b53dc1c 100644
--- a/jobs/process/BaseExtensionProcess.py
+++ b/jobs/process/BaseExtensionProcess.py
@@ -4,10 +4,6 @@ from jobs.process.BaseProcess import BaseProcess
 
 
 class BaseExtensionProcess(BaseProcess):
-    process_id: int
-    config: OrderedDict
-    progress_bar: ForwardRef('tqdm') = None
-
     def __init__(
             self,
             process_id: int,
@@ -15,6 +11,9 @@ class BaseExtensionProcess(BaseProcess):
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
+        self.process_id: int
+        self.config: OrderedDict
+        self.progress_bar: ForwardRef('tqdm') = None
 
     def run(self):
         super().run()
diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py
index 009bb7b..ac10da5 100644
--- a/jobs/process/BaseExtractProcess.py
+++ b/jobs/process/BaseExtractProcess.py
@@ -12,11 +12,6 @@ from toolkit.train_tools import get_torch_dtype
 
 
 class BaseExtractProcess(BaseProcess):
-    process_id: int
-    config: OrderedDict
-    output_folder: str
-    output_filename: str
-    output_path: str
 
     def __init__(
             self,
@@ -25,6 +20,10 @@ class BaseExtractProcess(BaseProcess):
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
+        self.config: OrderedDict
+        self.output_folder: str
+        self.output_filename: str
+        self.output_path: str
         self.process_id = process_id
         self.job = job
         self.config = config
diff --git a/jobs/process/BaseMergeProcess.py b/jobs/process/BaseMergeProcess.py
index d5396dc..55dfec6 100644
--- a/jobs/process/BaseMergeProcess.py
+++ b/jobs/process/BaseMergeProcess.py
@@ -9,8 +9,6 @@ from toolkit.train_tools import get_torch_dtype
 
 
 class BaseMergeProcess(BaseProcess):
-    process_id: int
-    config: OrderedDict
 
     def __init__(
             self,
@@ -19,6 +17,8 @@ class BaseMergeProcess(BaseProcess):
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
+        self.process_id: int
+        self.config: OrderedDict
         self.output_path = self.get_conf('output_path', required=True)
         self.dtype = self.get_conf('dtype', self.job.dtype)
         self.torch_dtype = get_torch_dtype(self.dtype)
diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py
index c21e7ae..7a7a69f 100644
--- a/jobs/process/BaseProcess.py
+++ b/jobs/process/BaseProcess.py
@@ -4,7 +4,6 @@ from collections import OrderedDict
 
 
 class BaseProcess(object):
-    meta: OrderedDict
 
     def __init__(
             self,
@@ -13,6 +12,7 @@ class BaseProcess(object):
            config: OrderedDict
     ):
         self.process_id = process_id
+        self.meta: OrderedDict
         self.job = job
         self.config = config
         self.raw_process_config = config
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 9bb0a17..bd8790d 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -9,6 +9,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
 from toolkit.embedding import Embedding
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
+from toolkit.paths import CONFIG_ROOT
 from toolkit.scheduler import get_lr_scheduler
 from toolkit.stable_diffusion_model import StableDiffusion
 
@@ -31,11 +32,12 @@ def flush():
 
 
 class BaseSDTrainProcess(BaseTrainProcess):
-    sd: StableDiffusion
-    embedding: Union[Embedding, None] = None
 
     def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
         super().__init__(process_id, job, config)
+        self.sd: StableDiffusion
+        self.embedding: Union[Embedding, None] = None
+
         self.custom_pipeline = custom_pipeline
         self.step_num = 0
         self.start_step = 0
@@ -344,7 +346,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # remove grads for these
         noisy_latents.requires_grad = False
+        noisy_latents = noisy_latents.detach()
         noise.requires_grad = False
+        noise = noise.detach()
 
         return noisy_latents, noise, timesteps, conditioned_prompts, imgs
diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py
index d1c65bf..f594704 100644
--- a/jobs/process/BaseTrainProcess.py
+++ b/jobs/process/BaseTrainProcess.py
@@ -14,11 +14,6 @@ if TYPE_CHECKING:
 
 
 class BaseTrainProcess(BaseProcess):
-    process_id: int
-    config: OrderedDict
-    writer: 'SummaryWriter'
-    job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
-    progress_bar: 'tqdm' = None
 
     def __init__(
             self,
@@ -27,6 +22,12 @@ class BaseTrainProcess(BaseProcess):
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
+        self.process_id: int
+        self.config: OrderedDict
+        self.writer: 'SummaryWriter'
+        self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
+        self.progress_bar: 'tqdm' = None
+        self.progress_bar = None
         self.writer = None
         self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)
diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py
index 89eb5ee..a005ae0 100644
--- a/jobs/process/GenerateProcess.py
+++ b/jobs/process/GenerateProcess.py
@@ -16,9 +16,9 @@ import random
 
 
 class GenerateConfig:
-    prompts: List[str]
 
     def __init__(self, **kwargs):
+        self.prompts: List[str]
         self.sampler = kwargs.get('sampler', 'ddpm')
         self.width = kwargs.get('width', 512)
         self.height = kwargs.get('height', 512)
diff --git a/jobs/process/ModRescaleLoraProcess.py b/jobs/process/ModRescaleLoraProcess.py
index ff8304d..8bb7436 100644
--- a/jobs/process/ModRescaleLoraProcess.py
+++ b/jobs/process/ModRescaleLoraProcess.py
@@ -24,6 +24,9 @@
             config: OrderedDict
     ):
         super().__init__(process_id, job, config)
+        self.process_id: int
+        self.config: OrderedDict
+        self.progress_bar: ForwardRef('tqdm') = None
         self.input_path = self.get_conf('input_path', required=True)
         self.output_path = self.get_conf('output_path', required=True)
         self.replace_meta = self.get_conf('replace_meta', default=False)
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index f13c9e4..0a47708 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -91,37 +91,38 @@ class LoRAModule(torch.nn.Module):
     # allowing us to run positive and negative weights in the same batch
     # really only useful for slider training for now
     def get_multiplier(self, lora_up):
-        batch_size = lora_up.size(0)
-        # batch will have all negative prompts first and positive prompts second
-        # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
-        # if there is more than our multiplier, it is likely a batch size increase, so we need to
-        # interleave the multipliers
-        if isinstance(self.multiplier, list):
-            if len(self.multiplier) == 0:
-                # single item, just return it
-                return self.multiplier[0]
-            elif len(self.multiplier) == batch_size:
-                # not doing CFG
-                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
+        with torch.no_grad():
+            batch_size = lora_up.size(0)
+            # batch will have all negative prompts first and positive prompts second
+            # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
+            # if there is more than our multiplier, it is likely a batch size increase, so we need to
+            # interleave the multipliers
+            if isinstance(self.multiplier, list):
+                if len(self.multiplier) == 0:
+                    # single item, just return it
+                    return self.multiplier[0]
+                elif len(self.multiplier) == batch_size:
+                    # not doing CFG
+                    multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
+                else:
+
+                    # we have a list of multipliers, so we need to get the multiplier for this batch
+                    multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
+                    # should be 1 for if total batch size was 1
+                    num_interleaves = (batch_size // 2) // len(self.multiplier)
+                    multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
+
+                # match lora_up rank
+                if len(lora_up.size()) == 2:
+                    multiplier_tensor = multiplier_tensor.view(-1, 1)
+                elif len(lora_up.size()) == 3:
+                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
+                elif len(lora_up.size()) == 4:
+                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
+                return multiplier_tensor.detach()
+
+            else:
-
-                # we have a list of multipliers, so we need to get the multiplier for this batch
-                multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
-                # should be 1 for if total batch size was 1
-                num_interleaves = (batch_size // 2) // len(self.multiplier)
-                multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
-
-            # match lora_up rank
-            if len(lora_up.size()) == 2:
-                multiplier_tensor = multiplier_tensor.view(-1, 1)
-            elif len(lora_up.size()) == 3:
-                multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
-            elif len(lora_up.size()) == 4:
-                multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
-            return multiplier_tensor
-
-        else:
-            return self.multiplier
+                return self.multiplier
 
     def _call_forward(self, x):
         # module dropout
@@ -152,35 +153,38 @@ class LoRAModule(torch.nn.Module):
 
         lx = self.lora_up(lx)
 
-        multiplier = self.get_multiplier(lx)
-
-        return lx * multiplier * scale
+        return lx * scale
 
     def forward(self, x):
         org_forwarded = self.org_forward(x)
         lora_output = self._call_forward(x)
 
         if self.is_normalizing:
-            # get a dim array from orig forward that had index of all dimensions except the batch and channel
+            with torch.no_grad():
+                # do this calculation without multiplier
+                # get a dim array from orig forward that had index of all dimensions except the batch and channel
 
-            # Calculate the target magnitude for the combined output
-            orig_max = torch.max(torch.abs(org_forwarded))
+                # Calculate the target magnitude for the combined output
+                orig_max = torch.max(torch.abs(org_forwarded))
 
-            # Calculate the additional increase in magnitude that lora_output would introduce
-            potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
+                # Calculate the additional increase in magnitude that lora_output would introduce
+                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
 
-            epsilon = 1e-6  # Small constant to avoid division by zero
+                epsilon = 1e-6  # Small constant to avoid division by zero
 
-            # Calculate the scaling factor for the lora_output
-            # to ensure that the potential increase in magnitude doesn't change the original max
-            normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
+                # Calculate the scaling factor for the lora_output
+                # to ensure that the potential increase in magnitude doesn't change the original max
+                normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
+                normalize_scaler = normalize_scaler.detach()
 
-            # save the scaler so it can be applied later
-            self.normalize_scaler = normalize_scaler.clone().detach()
+                # save the scaler so it can be applied later
+                self.normalize_scaler = normalize_scaler.clone().detach()
 
             lora_output *= normalize_scaler
 
-        return org_forwarded + lora_output
+        multiplier = self.get_multiplier(lora_output)
+
+        return org_forwarded + (lora_output * multiplier)
 
     def enable_gradient_checkpointing(self):
         self.is_checkpointing = True
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5794a57..b831cdd 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -610,6 +610,7 @@ class StableDiffusion:
             )
         )
 
+    @torch.no_grad()
     def encode_images(
             self,
             image_list: List[torch.Tensor],
@@ -625,6 +626,8 @@ class StableDiffusion:
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
             self.vae.to(self.device)
+        self.vae.eval()
+        self.vae.requires_grad_(False)
 
         # move to device and dtype
         image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
@@ -635,8 +638,9 @@ class StableDiffusion:
                 image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
 
         images = torch.stack(image_list)
+        flush()
         latents = self.vae.encode(images).latent_dist.sample()
-        latents = latents * 0.18215
+        latents = latents * self.vae.config['scaling_factor']
         latents = latents.to(device, dtype=dtype)
 
         return latents
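
Note on the pattern behind these changes: the memory savings come from keeping autograd out of inference-only work (torch.no_grad(), .detach(), requires_grad_(False)) and from explicitly handing cached CUDA blocks back to the allocator between large allocations (the flush() helper that pairs gc.collect() with torch.cuda.empty_cache()). The sketch below is illustrative only, not the exact code in toolkit/stable_diffusion_model.py; it assumes a diffusers-style AutoencoderKL as the VAE, and the function and parameter names are placeholders.

import gc

import torch
from diffusers import AutoencoderKL


def flush():
    # drop dangling Python references first, then release cached blocks to the CUDA allocator
    gc.collect()
    torch.cuda.empty_cache()


@torch.no_grad()  # nothing inside records an autograd graph, so activations are freed immediately
def encode_images(vae: AutoencoderKL, images: torch.Tensor) -> torch.Tensor:
    vae.eval()
    vae.requires_grad_(False)  # frozen VAE: no grads and no optimizer state for its weights
    flush()                    # clear leftovers before the large encode allocation
    latents = vae.encode(images).latent_dist.sample()
    # read the factor from the VAE config instead of hard-coding 0.18215 (SDXL's VAE uses 0.13025)
    latents = latents * vae.config.scaling_factor
    return latents.detach()    # already grad-free here; detach() makes the intent explicit

The same idea drives the LoRAModule changes above: the CFG multiplier and the normalization scaler are computed under no_grad() and detached, so only the lora_up/lora_down projections remain in the training graph.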