Huge memory optimizations, many big fixes

This commit is contained in:
Jaret Burkett
2023-08-27 17:48:02 -06:00
parent cc49786ee9
commit c446f768ea
15 changed files with 86 additions and 78 deletions

View File

@@ -91,37 +91,38 @@ class LoRAModule(torch.nn.Module):
# allowing us to run positive and negative weights in the same batch
# really only useful for slider training for now
def get_multiplier(self, lora_up):
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if there is more than our multiplier, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 0:
# single item, just return it
return self.multiplier[0]
elif len(self.multiplier) == batch_size:
# not doing CFG
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
with torch.no_grad():
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if there is more than our multiplier, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 0:
# single item, just return it
return self.multiplier[0]
elif len(self.multiplier) == batch_size:
# not doing CFG
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 for if total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor.detach()
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 for if total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor
else:
return self.multiplier
return self.multiplier
def _call_forward(self, x):
# module dropout
@@ -152,35 +153,38 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
multiplier = self.get_multiplier(lx)
return lx * multiplier * scale
return lx * scale
def forward(self, x):
org_forwarded = self.org_forward(x)
lora_output = self._call_forward(x)
if self.is_normalizing:
# get a dim array from orig forward that had index of all dimensions except the batch and channel
with torch.no_grad():
# do this calculation without multiplier
# get a dim array from orig forward that had index of all dimensions except the batch and channel
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output *= normalize_scaler
return org_forwarded + lora_output
multiplier = self.get_multiplier(lora_output)
return org_forwarded + (lora_output * multiplier)
def enable_gradient_checkpointing(self):
self.is_checkpointing = True

View File

@@ -610,6 +610,7 @@ class StableDiffusion:
)
)
@torch.no_grad()
def encode_images(
self,
image_list: List[torch.Tensor],
@@ -625,6 +626,8 @@ class StableDiffusion:
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(self.device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
@@ -635,8 +638,9 @@ class StableDiffusion:
image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
images = torch.stack(image_list)
flush()
latents = self.vae.encode(images).latent_dist.sample()
latents = latents * 0.18215
latents = latents * self.vae.config['scaling_factor']
latents = latents.to(device, dtype=dtype)
return latents