mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-01 16:49:56 +00:00
Bug fixes. added ability to use l1 loss. varous other tests and improvements
This commit is contained in:
@@ -96,6 +96,8 @@ class DataLoaderBatchDTO:
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_tensor: Union[torch.Tensor, None] = None
|
||||
self.unconditional_latents: Union[torch.Tensor, None] = None
|
||||
self.clip_image_embeds: Union[List[dict], None] = None
|
||||
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
@@ -183,6 +185,23 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
unconditional_tensor.append(x.unconditional_tensor)
|
||||
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
|
||||
|
||||
if any([x.clip_image_embeds is not None for x in self.file_items]):
|
||||
self.clip_image_embeds = []
|
||||
for x in self.file_items:
|
||||
if x.clip_image_embeds is not None:
|
||||
self.clip_image_embeds.append(x.clip_image_embeds)
|
||||
else:
|
||||
raise Exception("clip_image_embeds is None for some file items")
|
||||
|
||||
if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
|
||||
self.clip_image_embeds_unconditional = []
|
||||
for x in self.file_items:
|
||||
if x.clip_image_embeds_unconditional is not None:
|
||||
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
|
||||
else:
|
||||
raise Exception("clip_image_embeds_unconditional is None for some file items")
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user