Allow short and long caption combinations, like those from the new captioning system. Merge the network into the model before inference and re-extract it when done; this doubles inference speed on LoCon models. Allow splitting a batch into individual items and running them through alone. Basically gradient accumulation with a batch size of one.

Jaret Burkett
2023-10-24 16:02:07 -06:00
parent 73c8b50975
commit 002279cec3
9 changed files with 315 additions and 115 deletions
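For context, a minimal sketch of the two mechanisms named in the commit message, assuming a plain linear LoRA-style layer and a generic training loop. merge_lora, unmerge_lora, train_step_single_item, model, and loss_fn are illustrative names, not the ai-toolkit API.

    import torch

    def merge_lora(weight, lora_up, lora_down, scale):
        # Fold the low-rank delta into the base weight: W <- W + scale * (up @ down).
        # Inference then runs at base-model speed; keep the delta so the
        # network can be re-extracted (subtracted back out) when done.
        delta = scale * (lora_up @ lora_down)
        with torch.no_grad():
            weight += delta
        return delta

    def unmerge_lora(weight, delta):
        # Undo the merge, restoring the original base weight.
        with torch.no_grad():
            weight -= delta

    def train_step_single_item(model, optimizer, loss_fn, inputs, targets):
        # Run each batch element through alone and accumulate gradients,
        # then take one optimizer step: gradient accumulation with batch size 1.
        optimizer.zero_grad(set_to_none=True)
        batch_size = inputs.shape[0]
        for x, y in zip(torch.chunk(inputs, batch_size, dim=0),
                        torch.chunk(targets, batch_size, dim=0)):
            loss = loss_fn(model(x), y)
            # dividing by batch_size makes the summed gradients match the
            # full-batch mean loss (one common convention)
            (loss / batch_size).backward()
        optimizer.step()

In the diff below, the same accumulation effect comes from chunking noisy_latents, noise, timesteps, and the rest with torch.chunk(tensor, batch_size, dim=0), then calling loss.backward() once per item before a single optimizer.step().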


@@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
         with torch.no_grad():
             dtype = get_torch_dtype(self.train_config.dtype)
             # don't use network on this
-            self.network.multiplier = 0.0
+            # self.network.multiplier = 0.0
+            was_network_active = self.network.is_active
+            self.network.is_active = False
             self.sd.unet.eval()
             prior_pred = self.sd.predict_noise(
                 latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
@@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess):
             if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
                 del pred_kwargs['down_block_additional_residuals']
             # restore network
-            self.network.multiplier = network_weight_list
+            # self.network.multiplier = network_weight_list
+            self.network.is_active = was_network_active
         return prior_pred
 
     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
@@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess):
         dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         network_weight_list = batch.get_network_weight_list()
+        if self.train_config.single_item_batching:
+            network_weight_list = network_weight_list + network_weight_list
 
         has_adapter_img = batch.control_tensor is not None
@@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess):
         # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
         # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
 
-        mask_multiplier = 1.0
+        mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
         if batch.mask_tensor is not None:
             with self.timer('get_mask_multiplier'):
                 # upsampling not supported for bfloat16
@@ -297,107 +302,152 @@ class SDTrainer(BaseSDTrainProcess):
         self.optimizer.zero_grad(set_to_none=True)
 
-        # activate network if it exists
-        with network:
-            with self.timer('encode_prompt'):
-                if grad_on_text_encoder:
-                    with torch.set_grad_enabled(True):
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
-                            self.device_torch,
-                            dtype=dtype)
-                else:
-                    with torch.set_grad_enabled(False):
-                        # make sure it is in eval mode
-                        if isinstance(self.sd.text_encoder, list):
-                            for te in self.sd.text_encoder:
-                                te.eval()
-                        else:
-                            self.sd.text_encoder.eval()
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
-                            self.device_torch,
-                            dtype=dtype)
-                    # detach the embeddings
-                    conditional_embeds = conditional_embeds.detach()
-
-            # flush()
-            pred_kwargs = {}
-            if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
-                with torch.set_grad_enabled(self.adapter is not None):
-                    adapter = self.adapter if self.adapter else self.assistant_adapter
-                    adapter_multiplier = get_adapter_multiplier()
-                    with self.timer('encode_adapter'):
-                        down_block_additional_residuals = adapter(adapter_images)
-                        if self.assistant_adapter:
-                            # not training. detach
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-                        else:
-                            down_block_additional_residuals = [
-                                sample.to(dtype=dtype) * adapter_multiplier for sample in
-                                down_block_additional_residuals
-                            ]
-
-                        pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
-
-            prior_pred = None
-            if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
-                with self.timer('prior predict'):
-                    prior_pred = self.get_prior_prediction(
-                        noisy_latents=noisy_latents,
-                        conditional_embeds=conditional_embeds,
-                        match_adapter_assist=match_adapter_assist,
-                        network_weight_list=network_weight_list,
-                        timesteps=timesteps,
-                        pred_kwargs=pred_kwargs,
-                        noise=noise,
-                        batch=batch,
-                    )
-
-            if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
-                with self.timer('encode_adapter'):
-                    with torch.no_grad():
-                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
-                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
-
-            with self.timer('predict_unet'):
-                noise_pred = self.sd.predict_noise(
-                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
-                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-                    timestep=timesteps,
-                    guidance_scale=1.0,
-                    **pred_kwargs
-                )
-
-            with self.timer('calculate_loss'):
-                noise = noise.to(self.device_torch, dtype=dtype).detach()
-                loss = self.calculate_loss(
-                    noise_pred=noise_pred,
-                    noise=noise,
-                    noisy_latents=noisy_latents,
-                    timesteps=timesteps,
-                    batch=batch,
-                    mask_multiplier=mask_multiplier,
-                    prior_pred=prior_pred,
-                )
-            # check if nan
-            if torch.isnan(loss):
-                raise ValueError("loss is nan")
-
-        with self.timer('backward'):
-            # IMPORTANT if gradient checkpointing do not leave with network when doing backward
-            # it will destroy the gradients. This is because the network is a context manager
-            # and will change the multipliers back to 0.0 when exiting. They will be
-            # 0.0 for the backward pass and the gradients will be 0.0
-            # I spent weeks fighting this. DON'T DO IT
-            # with fsdp_overlap_step_with_backward():
-            loss.backward()
+        # make the batch splits
+        if self.train_config.single_item_batching:
+            batch_size = noisy_latents.shape[0]
+            # chunk/split everything
+            noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
+            noise_list = torch.chunk(noise, batch_size, dim=0)
+            timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
+            conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+            if imgs is not None:
+                imgs_list = torch.chunk(imgs, batch_size, dim=0)
+            else:
+                imgs_list = [None for _ in range(batch_size)]
+            if adapter_images is not None:
+                adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
+            else:
+                adapter_images_list = [None for _ in range(batch_size)]
+            mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+        else:
+            # put it all in an array
+            noisy_latents_list = [noisy_latents]
+            noise_list = [noise]
+            timesteps_list = [timesteps]
+            conditioned_prompts_list = [conditioned_prompts]
+            imgs_list = [imgs]
+            adapter_images_list = [adapter_images]
+            mask_multiplier_list = [mask_multiplier]
+
+        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+                noisy_latents_list,
+                noise_list,
+                timesteps_list,
+                conditioned_prompts_list,
+                imgs_list,
+                adapter_images_list,
+                mask_multiplier_list
+        ):
+            with network:
+                with self.timer('encode_prompt'):
+                    if grad_on_text_encoder:
+                        with torch.set_grad_enabled(True):
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                                self.device_torch,
+                                dtype=dtype)
+                    else:
+                        with torch.set_grad_enabled(False):
+                            # make sure it is in eval mode
+                            if isinstance(self.sd.text_encoder, list):
+                                for te in self.sd.text_encoder:
+                                    te.eval()
+                            else:
+                                self.sd.text_encoder.eval()
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                                self.device_torch,
+                                dtype=dtype)
+                        # detach the embeddings
+                        conditional_embeds = conditional_embeds.detach()
+
+                # flush()
+                pred_kwargs = {}
+                if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
+                    with torch.set_grad_enabled(self.adapter is not None):
+                        adapter = self.adapter if self.adapter else self.assistant_adapter
+                        adapter_multiplier = get_adapter_multiplier()
+                        with self.timer('encode_adapter'):
+                            down_block_additional_residuals = adapter(adapter_images)
+                            if self.assistant_adapter:
+                                # not training. detach
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+                            else:
+                                down_block_additional_residuals = [
+                                    sample.to(dtype=dtype) * adapter_multiplier for sample in
+                                    down_block_additional_residuals
+                                ]
+
+                            pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+
+                prior_pred = None
+                if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
+                    with self.timer('prior predict'):
+                        prior_pred = self.get_prior_prediction(
+                            noisy_latents=noisy_latents,
+                            conditional_embeds=conditional_embeds,
+                            match_adapter_assist=match_adapter_assist,
+                            network_weight_list=network_weight_list,
+                            timesteps=timesteps,
+                            pred_kwargs=pred_kwargs,
+                            noise=noise,
+                            batch=batch,
+                        )
+
+                if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
+                    with self.timer('encode_adapter'):
+                        with torch.no_grad():
+                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
+                        conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
+
+                with self.timer('predict_unet'):
+                    noise_pred = self.sd.predict_noise(
+                        latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                        conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        timestep=timesteps,
+                        guidance_scale=1.0,
+                        **pred_kwargs
+                    )
+
+                with self.timer('calculate_loss'):
+                    noise = noise.to(self.device_torch, dtype=dtype).detach()
+                    loss = self.calculate_loss(
+                        noise_pred=noise_pred,
+                        noise=noise,
+                        noisy_latents=noisy_latents,
+                        timesteps=timesteps,
+                        batch=batch,
+                        mask_multiplier=mask_multiplier,
+                        prior_pred=prior_pred,
+                    )
+                # check if nan
+                if torch.isnan(loss):
+                    raise ValueError("loss is nan")
+
+            with self.timer('backward'):
+                # IMPORTANT if gradient checkpointing do not leave with network when doing backward
+                # it will destroy the gradients. This is because the network is a context manager
+                # and will change the multipliers back to 0.0 when exiting. They will be
+                # 0.0 for the backward pass and the gradients will be 0.0
+                # I spent weeks fighting this. DON'T DO IT
+                loss.backward()
 
         torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
 
         # flush()
         with self.timer('optimizer_step'):
             # apply gradients
             self.optimizer.step()
             self.optimizer.zero_grad(set_to_none=True)
 
         with self.timer('scheduler_step'):
             self.lr_scheduler.step()