mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-02-05 04:59:56 +00:00
Huge memory optimizations, many bug fixes
@@ -14,12 +14,9 @@ def flush():


class SDTrainer(BaseSDTrainProcess):
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        pass

    def before_model_load(self):
        pass

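The `flush()` helper named in the hunk header above is called between training phases throughout this commit to reclaim GPU memory. Its body is not shown in this excerpt; a minimal sketch of the usual pattern, assuming a CUDA device:

```python
import gc
import torch

def flush():
    # Drop dead Python references first, then release cached CUDA blocks
    # so the allocator's peak usage falls between training phases.
    gc.collect()
    torch.cuda.empty_cache()
```
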
@@ -40,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)

        self.optimizer.zero_grad()
        flush()

        # text encoding
        grad_on_text_encoder = False
@@ -71,7 +69,7 @@ class SDTrainer(BaseSDTrainProcess):
            timestep=timesteps,
            guidance_scale=1.0,
        )

        # 9.18 gb
        noise = noise.to(self.device_torch, dtype=dtype)

        if self.sd.prediction_type == 'v_prediction':

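The `v_prediction` branch switches the training target from raw noise to velocity. The branch body is cut off in this excerpt; a hedged sketch of how that target is typically formed with diffusers (scheduler choice and shapes here are illustrative, not this repo's exact code):

```python
import torch
from diffusers import DDPMScheduler

# v-prediction target: v = alpha_t * noise - sigma_t * x0
scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

latents = torch.randn(2, 4, 64, 64)       # clean latents x0
noise = torch.randn_like(latents)         # sampled noise
timesteps = torch.randint(0, 1000, (2,))  # one timestep per sample

# the loss then compares the model's output against this velocity target
target = scheduler.get_velocity(latents, noise, timesteps)
```
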
@@ -6,19 +6,16 @@ from jobs.process import BaseProcess


class BaseJob:
    config: OrderedDict
    job: str
    name: str
    meta: OrderedDict
    process: List[BaseProcess]

    def __init__(self, config: OrderedDict):
        if not config:
            raise ValueError('config is required')
        self.process: List[BaseProcess]

        self.config = config['config']
        self.raw_config = config
        self.job = config['job']
        self.torch_profiler = self.get_conf('torch_profiler', False)
        self.name = self.get_conf('name', required=True)
        if 'meta' in config:
            self.meta = config['meta']

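`get_conf` is used above with both a plain default and `required=True`. The helper itself is not part of this diff; a minimal sketch of the lookup semantics it appears to implement (the free-function form and names are assumptions):

```python
from collections import OrderedDict

def get_conf(config: OrderedDict, key: str, default=None, required: bool = False):
    # Return config[key] when present, otherwise the default;
    # a missing required key is a configuration error.
    if key in config:
        return config[key]
    if required:
        raise ValueError(f'config requires "{key}"')
    return default

job_config = OrderedDict(name='my_job')
assert get_conf(job_config, 'name', required=True) == 'my_job'
assert get_conf(job_config, 'torch_profiler', False) is False
```
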
@@ -1,7 +1,8 @@
import os
from collections import OrderedDict
from jobs import BaseJob
from toolkit.extension import get_all_extensions_process_dict

from toolkit.paths import CONFIG_ROOT

class ExtensionJob(BaseJob):

@@ -14,7 +14,6 @@ process_dict = {


class GenerateJob(BaseJob):
    process: List[GenerateProcess]

    def __init__(self, config: OrderedDict):
        super().__init__(config)

@@ -26,7 +26,6 @@ process_dict = {


class TrainJob(BaseJob):
    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        super().__init__(config)

@@ -4,10 +4,6 @@ from jobs.process.BaseProcess import BaseProcess


class BaseExtensionProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
@@ -15,6 +11,9 @@ class BaseExtensionProcess(BaseProcess):
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.progress_bar: ForwardRef('tqdm') = None

    def run(self):
        super().run()

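A pattern that repeats across this commit: class-level attribute declarations are moved into `__init__` as bare annotations. A bare `self.x: T` annotation allocates nothing, while a class-level default such as `progress_bar = None` is state stored once on the class object. A small illustration of why a mutable class-level default is risky:

```python
class Shared:
    items = []  # one list object, shared by every instance

class PerInstance:
    def __init__(self):
        self.items = []         # a fresh list per instance
        self.progress_bar: int  # bare annotation: a type hint only, nothing is assigned

a, b = Shared(), Shared()
a.items.append(1)
assert b.items == [1]  # b observes a's mutation

c, d = PerInstance(), PerInstance()
c.items.append(1)
assert d.items == []   # instances stay independent
```
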
@@ -12,11 +12,6 @@ from toolkit.train_tools import get_torch_dtype


class BaseExtractProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    output_folder: str
    output_filename: str
    output_path: str

    def __init__(
            self,
@@ -25,6 +20,10 @@ class BaseExtractProcess(BaseProcess):
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.config: OrderedDict
        self.output_folder: str
        self.output_filename: str
        self.output_path: str
        self.process_id = process_id
        self.job = job
        self.config = config

@@ -9,8 +9,6 @@ from toolkit.train_tools import get_torch_dtype


class BaseMergeProcess(BaseProcess):
    process_id: int
    config: OrderedDict

    def __init__(
            self,
@@ -19,6 +17,8 @@ class BaseMergeProcess(BaseProcess):
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.output_path = self.get_conf('output_path', required=True)
        self.dtype = self.get_conf('dtype', self.job.dtype)
        self.torch_dtype = get_torch_dtype(self.dtype)

@@ -4,7 +4,6 @@ from collections import OrderedDict


class BaseProcess(object):
    meta: OrderedDict

    def __init__(
            self,
@@ -13,6 +12,7 @@ class BaseProcess(object):
            config: OrderedDict
    ):
        self.process_id = process_id
        self.meta: OrderedDict
        self.job = job
        self.config = config
        self.raw_process_config = config

@@ -9,6 +9,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT

from toolkit.scheduler import get_lr_scheduler
from toolkit.stable_diffusion_model import StableDiffusion
@@ -31,11 +32,12 @@ def flush():


class BaseSDTrainProcess(BaseTrainProcess):
    sd: StableDiffusion
    embedding: Union[Embedding, None] = None

    def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
        super().__init__(process_id, job, config)
        self.sd: StableDiffusion
        self.embedding: Union[Embedding, None] = None

        self.custom_pipeline = custom_pipeline
        self.step_num = 0
        self.start_step = 0
@@ -344,7 +346,9 @@ class BaseSDTrainProcess(BaseTrainProcess):

        # remove grads for these
        noisy_latents.requires_grad = False
        noisy_latents = noisy_latents.detach()
        noise.requires_grad = False
        noise = noise.detach()

        return noisy_latents, noise, timesteps, conditioned_prompts, imgs

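The added `detach()` calls matter for memory beyond `requires_grad = False` alone: detaching returns a tensor cut off from whatever autograd graph preprocessing built, so that graph can be freed before the UNet forward pass allocates its own. A small demonstration:

```python
import torch

x = torch.randn(4, 4, requires_grad=True)
noisy_latents = x * 2  # preprocessing leaves an autograd graph behind
assert noisy_latents.grad_fn is not None

noisy_latents = noisy_latents.detach()  # same storage, no graph, no copy
assert noisy_latents.grad_fn is None
assert not noisy_latents.requires_grad
```
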
@@ -14,11 +14,6 @@ if TYPE_CHECKING:


class BaseTrainProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    writer: 'SummaryWriter'
    job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
    progress_bar: 'tqdm' = None

    def __init__(
            self,
@@ -27,6 +22,12 @@ class BaseTrainProcess(BaseProcess):
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.writer: 'SummaryWriter'
        self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
        self.progress_bar: 'tqdm' = None

        self.progress_bar = None
        self.writer = None
        self.training_folder = self.get_conf('training_folder', self.job.training_folder if hasattr(self.job, 'training_folder') else None)

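A side note on the last line above: `getattr` with a default collapses the `hasattr` check into a single expression. An equivalent spelling (a style observation, not part of the commit):

```python
self.training_folder = self.get_conf('training_folder', getattr(self.job, 'training_folder', None))
```
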
@@ -16,9 +16,9 @@ import random


class GenerateConfig:
    prompts: List[str]

    def __init__(self, **kwargs):
        self.prompts: List[str]
        self.sampler = kwargs.get('sampler', 'ddpm')
        self.width = kwargs.get('width', 512)
        self.height = kwargs.get('height', 512)

@@ -24,6 +24,9 @@ class ModRescaleLoraProcess(BaseProcess):
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        self.process_id: int
        self.config: OrderedDict
        self.progress_bar: ForwardRef('tqdm') = None
        self.input_path = self.get_conf('input_path', required=True)
        self.output_path = self.get_conf('output_path', required=True)
        self.replace_meta = self.get_conf('replace_meta', default=False)

@@ -91,37 +91,38 @@ class LoRAModule(torch.nn.Module):
    # allowing us to run positive and negative weights in the same batch
    # really only useful for slider training for now
    def get_multiplier(self, lora_up):
        batch_size = lora_up.size(0)
        # batch will have all negative prompts first and positive prompts second
        # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
        # if there is more than our multiplier, it is likely a batch size increase, so we need to
        # interleave the multipliers
        if isinstance(self.multiplier, list):
            if len(self.multiplier) == 0:
                # single item, just return it
                return self.multiplier[0]
            elif len(self.multiplier) == batch_size:
                # not doing CFG
                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
        with torch.no_grad():
            batch_size = lora_up.size(0)
            # batch will have all negative prompts first and positive prompts second
            # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
            # if there is more than our multiplier, it is likely a batch size increase, so we need to
            # interleave the multipliers
            if isinstance(self.multiplier, list):
                if len(self.multiplier) == 0:
                    # single item, just return it
                    return self.multiplier[0]
                elif len(self.multiplier) == batch_size:
                    # not doing CFG
                    multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
                else:

                    # we have a list of multipliers, so we need to get the multiplier for this batch
                    multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                    # should be 1 for if total batch size was 1
                    num_interleaves = (batch_size // 2) // len(self.multiplier)
                    multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)

                # match lora_up rank
                if len(lora_up.size()) == 2:
                    multiplier_tensor = multiplier_tensor.view(-1, 1)
                elif len(lora_up.size()) == 3:
                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
                elif len(lora_up.size()) == 4:
                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
                return multiplier_tensor.detach()

            else:

                # we have a list of multipliers, so we need to get the multiplier for this batch
                multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                # should be 1 for if total batch size was 1
                num_interleaves = (batch_size // 2) // len(self.multiplier)
                multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)

            # match lora_up rank
            if len(lora_up.size()) == 2:
                multiplier_tensor = multiplier_tensor.view(-1, 1)
            elif len(lora_up.size()) == 3:
                multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
            elif len(lora_up.size()) == 4:
                multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
            return multiplier_tensor

        else:
            return self.multiplier
                return self.multiplier

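A worked example of the interleave arithmetic above, assuming a CFG batch laid out as all negative samples followed by all positive samples, with one multiplier per prompt pair:

```python
import torch

multiplier = [0.5, -0.5]  # weights for two prompt pairs
batch_size = 8            # 4 negative + 4 positive samples

# repeat for the negative and positive halves, then interleave
# across the per-pair batch expansion
multiplier_tensor = torch.tensor(multiplier * 2)        # [0.5, -0.5, 0.5, -0.5]
num_interleaves = (batch_size // 2) // len(multiplier)  # 2
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
print(multiplier_tensor)
# tensor([ 0.5000,  0.5000, -0.5000, -0.5000,  0.5000,  0.5000, -0.5000, -0.5000])

# reshape to broadcast against a 4-D lora_up activation
print(multiplier_tensor.view(-1, 1, 1, 1).shape)  # torch.Size([8, 1, 1, 1])
```
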
    def _call_forward(self, x):
        # module dropout
@@ -152,35 +153,38 @@ class LoRAModule(torch.nn.Module):

        lx = self.lora_up(lx)

        multiplier = self.get_multiplier(lx)

        return lx * multiplier * scale
        return lx * scale

    def forward(self, x):
        org_forwarded = self.org_forward(x)
        lora_output = self._call_forward(x)

        if self.is_normalizing:
            # get a dim array from orig forward that had index of all dimensions except the batch and channel
            with torch.no_grad():
                # do this calculation without multiplier
                # get a dim array from orig forward that had index of all dimensions except the batch and channel

            # Calculate the target magnitude for the combined output
            orig_max = torch.max(torch.abs(org_forwarded))
                # Calculate the target magnitude for the combined output
                orig_max = torch.max(torch.abs(org_forwarded))

            # Calculate the additional increase in magnitude that lora_output would introduce
            potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
                # Calculate the additional increase in magnitude that lora_output would introduce
                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))

            epsilon = 1e-6  # Small constant to avoid division by zero
                epsilon = 1e-6  # Small constant to avoid division by zero

            # Calculate the scaling factor for the lora_output
            # to ensure that the potential increase in magnitude doesn't change the original max
            normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
                # Calculate the scaling factor for the lora_output
                # to ensure that the potential increase in magnitude doesn't change the original max
                normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
                normalize_scaler = normalize_scaler.detach()

            # save the scaler so it can be applied later
            self.normalize_scaler = normalize_scaler.clone().detach()
                # save the scaler so it can be applied later
                self.normalize_scaler = normalize_scaler.clone().detach()

            lora_output *= normalize_scaler

        return org_forwarded + lora_output
        multiplier = self.get_multiplier(lora_output)

        return org_forwarded + (lora_output * multiplier)

    def enable_gradient_checkpointing(self):
        self.is_checkpointing = True

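The normalization math above can be read in isolation: the scaler shrinks the LoRA contribution so the combined output's peak magnitude stays near the original activation's peak, and computing it under `no_grad` keeps those statistics out of the autograd graph. A standalone sketch of the same arithmetic:

```python
import torch

org_forwarded = torch.randn(2, 8)      # original module output
lora_output = torch.randn(2, 8) * 0.3  # LoRA contribution

with torch.no_grad():  # statistics only; no gradients should flow through them
    orig_max = torch.max(torch.abs(org_forwarded))
    potential_max_increase = torch.max(
        torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded)
    )
    epsilon = 1e-6  # avoid division by zero
    normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)

combined = org_forwarded + lora_output * normalize_scaler
print(float(torch.max(torch.abs(org_forwarded))))  # original peak
print(float(torch.max(torch.abs(combined))))       # stays close to it
```
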
@@ -610,6 +610,7 @@ class StableDiffusion:
            )
        )

    @torch.no_grad()
    def encode_images(
            self,
            image_list: List[torch.Tensor],
@@ -625,6 +626,8 @@ class StableDiffusion:
        # Move vae to device if on cpu
        if self.vae.device == 'cpu':
            self.vae.to(self.device)
        self.vae.eval()
        self.vae.requires_grad_(False)
        # move to device and dtype
        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]

@@ -635,8 +638,9 @@ class StableDiffusion:
            image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)

        images = torch.stack(image_list)
        flush()
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215
        latents = latents * self.vae.config['scaling_factor']
        latents = latents.to(device, dtype=dtype)

        return latents

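The switch from the hard-coded `0.18215` to `self.vae.config['scaling_factor']` makes `encode_images` model-agnostic: SD 1.x/2.x VAEs ship with `scaling_factor = 0.18215`, while SDXL's VAE uses `0.13025`, so a hard-coded constant silently mis-scales latents for other model families. A hedged sketch of the same pattern with diffusers (the checkpoint name is illustrative):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
images = torch.randn(1, 3, 512, 512)  # toy batch, already scaled to [-1, 1]

with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()

# read the factor from the checkpoint instead of hard-coding it
latents = latents * vae.config.scaling_factor  # 0.18215 here, 0.13025 for SDXL
```
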