mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Allow short and long caption combinations like from the new captioning system. Merge the network into the model before inference and re-extract it when done; this doubles inference speed on LoCon models. Allow splitting a batch into individual components and running them through alone. Basically gradient accumulation with a single batch size.
This commit is contained in:
@@ -460,6 +460,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
prompts = batch.get_caption_list()
|
||||
is_reg_list = batch.get_is_reg_list()
|
||||
|
||||
is_any_reg = any([is_reg for is_reg in is_reg_list])
|
||||
|
||||
do_double = self.train_config.short_and_long_captions and not is_any_reg
|
||||
|
||||
if self.train_config.short_and_long_captions and do_double:
|
||||
# dont do this with regs. No point
|
||||
|
||||
# double batch and add short captions to the end
|
||||
prompts = prompts + batch.get_caption_short_list()
|
||||
is_reg_list = is_reg_list + is_reg_list
|
||||
|
||||
conditioned_prompts = []
|
||||
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
@@ -500,7 +511,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# we determine noise from the differential of the latents
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batch_size = len(batch.file_items)
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
@@ -582,6 +593,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# todo is this for sdxl? find out where this came from originally
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
def double_up_tensor(tensor: torch.Tensor):
|
||||
if tensor is None:
|
||||
return None
|
||||
return torch.cat([tensor, tensor], dim=0)
|
||||
|
||||
if do_double:
|
||||
noisy_latents = double_up_tensor(noisy_latents)
|
||||
noise = double_up_tensor(noise)
|
||||
timesteps = double_up_tensor(timesteps)
|
||||
# prompts are already updated above
|
||||
imgs = double_up_tensor(imgs)
|
||||
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
|
||||
batch.control_tensor = double_up_tensor(batch.control_tensor)
|
||||
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
noisy_latents = noisy_latents.detach()
|
||||
@@ -927,16 +953,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
if self.has_first_sample_requested:
|
||||
if self.has_first_sample_requested and self.step_num <= 1:
|
||||
self.print("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample:
|
||||
self.print("Skipping first sample due to config setting")
|
||||
else:
|
||||
elif self.step_num <= 1:
|
||||
self.print("Generating baseline samples before training")
|
||||
self.sample(0)
|
||||
self.sample(self.step_num)
|
||||
|
||||
self.progress_bar = ToolkitProgressBar(
|
||||
total=self.train_config.steps,
|
||||
|
||||
Reference in New Issue
Block a user