Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-05-01 03:31:35 +00:00
Performance optimizations for preprocessing the batch
@@ -1009,76 +1009,76 @@ class SDTrainer(BaseSDTrainProcess):
         return loss
 
     def train_single_accumulation(self, batch: DataLoaderBatchDTO):
-        self.timer.start('preprocess_batch')
-        if isinstance(self.adapter, CustomAdapter):
-            batch = self.adapter.edit_batch_raw(batch)
-        batch = self.preprocess_batch(batch)
-        if isinstance(self.adapter, CustomAdapter):
-            batch = self.adapter.edit_batch_processed(batch)
-        dtype = get_torch_dtype(self.train_config.dtype)
-        # sanity check
-        if self.sd.vae.dtype != self.sd.vae_torch_dtype:
-            self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
-        if isinstance(self.sd.text_encoder, list):
-            for encoder in self.sd.text_encoder:
-                if encoder.dtype != self.sd.te_torch_dtype:
-                    encoder.to(self.sd.te_torch_dtype)
-        else:
-            if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
-                self.sd.text_encoder.to(self.sd.te_torch_dtype)
-
-        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
-        if self.train_config.do_cfg or self.train_config.do_random_cfg:
-            # pick random negative prompts
-            if self.negative_prompt_pool is not None:
-                negative_prompts = []
-                for i in range(noisy_latents.shape[0]):
-                    num_neg = random.randint(1, self.train_config.max_negative_prompts)
-                    this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
-                    this_neg_prompt = ', '.join(this_neg_prompts)
-                    negative_prompts.append(this_neg_prompt)
-                self.batch_negative_prompt = negative_prompts
-            else:
-                self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
-
-        if self.adapter and isinstance(self.adapter, CustomAdapter):
-            # condition the prompt
-            # todo handle more than one adapter image
-            conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
-
-        network_weight_list = batch.get_network_weight_list()
-        if self.train_config.single_item_batching:
-            network_weight_list = network_weight_list + network_weight_list
-
-        has_adapter_img = batch.control_tensor is not None
-        has_clip_image = batch.clip_image_tensor is not None
-        has_clip_image_embeds = batch.clip_image_embeds is not None
-        # force it to be true if doing regs as we handle those differently
-        if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
-            has_clip_image = True
-            if self._clip_image_embeds_unconditional is not None:
-                has_clip_image_embeds = True  # we are caching embeds, handle that differently
-                has_clip_image = False
-
-        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
-            raise ValueError(
-                "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
-
-        match_adapter_assist = False
-
-        # check if we are matching the adapter assistant
-        if self.assistant_adapter:
-            if self.train_config.match_adapter_chance == 1.0:
-                match_adapter_assist = True
-            elif self.train_config.match_adapter_chance > 0.0:
-                match_adapter_assist = torch.rand(
-                    (1,), device=self.device_torch, dtype=dtype
-                ) < self.train_config.match_adapter_chance
-
-        self.timer.stop('preprocess_batch')
-
-        is_reg = False
         with torch.no_grad():
+            self.timer.start('preprocess_batch')
+            if isinstance(self.adapter, CustomAdapter):
+                batch = self.adapter.edit_batch_raw(batch)
+            batch = self.preprocess_batch(batch)
+            if isinstance(self.adapter, CustomAdapter):
+                batch = self.adapter.edit_batch_processed(batch)
+            dtype = get_torch_dtype(self.train_config.dtype)
+            # sanity check
+            if self.sd.vae.dtype != self.sd.vae_torch_dtype:
+                self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
+            if isinstance(self.sd.text_encoder, list):
+                for encoder in self.sd.text_encoder:
+                    if encoder.dtype != self.sd.te_torch_dtype:
+                        encoder.to(self.sd.te_torch_dtype)
+            else:
+                if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
+                    self.sd.text_encoder.to(self.sd.te_torch_dtype)
+
+            noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+            if self.train_config.do_cfg or self.train_config.do_random_cfg:
+                # pick random negative prompts
+                if self.negative_prompt_pool is not None:
+                    negative_prompts = []
+                    for i in range(noisy_latents.shape[0]):
+                        num_neg = random.randint(1, self.train_config.max_negative_prompts)
+                        this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
+                        this_neg_prompt = ', '.join(this_neg_prompts)
+                        negative_prompts.append(this_neg_prompt)
+                    self.batch_negative_prompt = negative_prompts
+                else:
+                    self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
+
+            if self.adapter and isinstance(self.adapter, CustomAdapter):
+                # condition the prompt
+                # todo handle more than one adapter image
+                conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
+
+            network_weight_list = batch.get_network_weight_list()
+            if self.train_config.single_item_batching:
+                network_weight_list = network_weight_list + network_weight_list
+
+            has_adapter_img = batch.control_tensor is not None
+            has_clip_image = batch.clip_image_tensor is not None
+            has_clip_image_embeds = batch.clip_image_embeds is not None
+            # force it to be true if doing regs as we handle those differently
+            if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
+                has_clip_image = True
+                if self._clip_image_embeds_unconditional is not None:
+                    has_clip_image_embeds = True  # we are caching embeds, handle that differently
+                    has_clip_image = False
+
+            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
+                raise ValueError(
+                    "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
+
+            match_adapter_assist = False
+
+            # check if we are matching the adapter assistant
+            if self.assistant_adapter:
+                if self.train_config.match_adapter_chance == 1.0:
+                    match_adapter_assist = True
+                elif self.train_config.match_adapter_chance > 0.0:
+                    match_adapter_assist = torch.rand(
+                        (1,), device=self.device_torch, dtype=dtype
+                    ) < self.train_config.match_adapter_chance
+
+            self.timer.stop('preprocess_batch')
+
+            is_reg = False
             loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
             for idx, file_item in enumerate(batch.file_items):
                 if file_item.is_reg:
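
Why moving the preprocessing under `torch.no_grad()` helps: the dtype casts, the VAE/text-encoder work inside `process_general_training_batch`, and the prompt bookkeeping only prepare inputs and never need gradients, so recording them in the autograd graph costs memory and time every step. Below is a minimal, self-contained sketch of the same pattern; the toy `vae`/`text_encoder` modules are illustrative stand-ins, not ai-toolkit code.

import torch
import torch.nn as nn

# Minimal sketch (not ai-toolkit code): tiny stand-in modules so the example runs on its own.
vae = nn.Linear(16, 4)            # pretend VAE: pixel features -> latents
text_encoder = nn.Linear(8, 4)    # pretend text encoder: token features -> embeddings

images = torch.randn(2, 16)
prompt_feats = torch.randn(2, 8)

# Batch preparation under no_grad: PyTorch records no autograd graph for these ops,
# so their activations are freed immediately instead of being held for a backward pass.
with torch.no_grad():
    latents = vae(images)
    text_embeds = text_encoder(prompt_feats)

assert not latents.requires_grad and not text_embeds.requires_grad

# The denoising model's forward/backward still runs with gradients as usual;
# only the input preparation above is excluded from the graph.

Note that `no_grad` only stops gradients from flowing back through the preprocessing itself; if the prepared tensors later feed a trainable module, that module's own parameters still receive gradients from its forward pass.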