Small updates and bug fixes for various things

This commit is contained in:
Jaret Burkett
2025-06-03 20:08:35 -06:00
parent b6d25fcd10
commit adc31ec77d
4 changed files with 56 additions and 6 deletions

View File

@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
batch: Optional['DataLoaderBatchDTO'] = None,
is_primary_pred: bool = False,
**kwargs,
):
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
is_primary_pred=True,
**pred_kwargs
)
self.after_unet_predict()