mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Small updates and bug fixes for various things
This commit is contained in:
@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
batch: Optional['DataLoaderBatchDTO'] = None,
|
||||
is_primary_pred: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
batch=batch,
|
||||
is_primary_pred=True,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
Reference in New Issue
Block a user