Bug fixes. added ability to use l1 loss. varous other tests and improvements

This commit is contained in:
Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions

View File

@@ -96,6 +96,8 @@ class DataLoaderBatchDTO:
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latents: Union[torch.Tensor, None] = None
self.clip_image_embeds: Union[List[dict], None] = None
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
@@ -183,6 +185,23 @@ class DataLoaderBatchDTO:
else:
unconditional_tensor.append(x.unconditional_tensor)
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
if any([x.clip_image_embeds is not None for x in self.file_items]):
self.clip_image_embeds = []
for x in self.file_items:
if x.clip_image_embeds is not None:
self.clip_image_embeds.append(x.clip_image_embeds)
else:
raise Exception("clip_image_embeds is None for some file items")
if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
self.clip_image_embeds_unconditional = []
for x in self.file_items:
if x.clip_image_embeds_unconditional is not None:
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
else:
raise Exception("clip_image_embeds_unconditional is None for some file items")
except Exception as e:
print(e)
raise e