mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Allow short and long caption combinations like form the new captioning system. Merge the network into the model before inference and reextract when done. Doubles inference speed on locon models during inference. allow splitting a batch into individual components and run them through alone. Basicallt gradient accumulation with single batch size.
This commit is contained in:
@@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# dont use network on this
|
||||
self.network.multiplier = 0.0
|
||||
# self.network.multiplier = 0.0
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
self.sd.unet.eval()
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
@@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
# restore network
|
||||
self.network.multiplier = network_weight_list
|
||||
# self.network.multiplier = network_weight_list
|
||||
self.network.is_active = was_network_active
|
||||
return prior_pred
|
||||
|
||||
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
|
||||
@@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
|
||||
@@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
mask_multiplier = 1.0
|
||||
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
if batch.mask_tensor is not None:
|
||||
with self.timer('get_mask_multiplier'):
|
||||
# upsampling no supported for bfloat16
|
||||
@@ -297,107 +302,152 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# activate network if it exits
|
||||
with network:
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
with torch.set_grad_enabled(False):
|
||||
# make sure it is in eval mode
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
# make the batch splits
|
||||
if self.train_config.single_item_batching:
|
||||
batch_size = noisy_latents.shape[0]
|
||||
# chunk/split everything
|
||||
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
|
||||
noise_list = torch.chunk(noise, batch_size, dim=0)
|
||||
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
|
||||
conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
|
||||
if imgs is not None:
|
||||
imgs_list = torch.chunk(imgs, batch_size, dim=0)
|
||||
else:
|
||||
imgs_list = [None for _ in range(batch_size)]
|
||||
if adapter_images is not None:
|
||||
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
|
||||
else:
|
||||
adapter_images_list = [None for _ in range(batch_size)]
|
||||
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
|
||||
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
adapter = self.adapter if self.adapter else self.assistant_adapter
|
||||
adapter_multiplier = get_adapter_multiplier()
|
||||
else:
|
||||
# but it all in an array
|
||||
noisy_latents_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditioned_prompts_list = [conditioned_prompts]
|
||||
imgs_list = [imgs]
|
||||
adapter_images_list = [adapter_images]
|
||||
mask_multiplier_list = [mask_multiplier]
|
||||
|
||||
|
||||
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
|
||||
noisy_latents_list,
|
||||
noise_list,
|
||||
timesteps_list,
|
||||
conditioned_prompts_list,
|
||||
imgs_list,
|
||||
adapter_images_list,
|
||||
mask_multiplier_list
|
||||
):
|
||||
|
||||
with network:
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
else:
|
||||
with torch.set_grad_enabled(False):
|
||||
# make sure it is in eval mode
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for te in self.sd.text_encoder:
|
||||
te.eval()
|
||||
else:
|
||||
self.sd.text_encoder.eval()
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
|
||||
self.device_torch,
|
||||
dtype=dtype)
|
||||
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
adapter = self.adapter if self.adapter else self.assistant_adapter
|
||||
adapter_multiplier = get_adapter_multiplier()
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = adapter(adapter_images)
|
||||
if self.assistant_adapter:
|
||||
# not training. detach
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
else:
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
prior_pred = None
|
||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = adapter(adapter_images)
|
||||
if self.assistant_adapter:
|
||||
# not training. detach
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
else:
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) * adapter_multiplier for sample in
|
||||
down_block_additional_residuals
|
||||
]
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
prior_pred = None
|
||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
conditional_embeds=conditional_embeds,
|
||||
match_adapter_assist=match_adapter_assist,
|
||||
network_weight_list=network_weight_list,
|
||||
timesteps=timesteps,
|
||||
pred_kwargs=pred_kwargs,
|
||||
noise=noise,
|
||||
batch=batch,
|
||||
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
with self.timer('backward'):
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
# with fsdp_overlap_step_with_backward():
|
||||
loss.backward()
|
||||
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
prior_pred=prior_pred,
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
with self.timer('backward'):
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# flush()
|
||||
|
||||
with self.timer('optimizer_step'):
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user