Bug fixes. Added some functionality to help with private extensions

This commit is contained in:
Jaret Burkett
2023-10-05 07:09:34 -06:00
parent 579650eaf8
commit f73402473b
8 changed files with 99 additions and 20 deletions

View File

@@ -84,8 +84,22 @@ class DataLoaderBatchDTO:
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
self.control_tensor: Union[torch.Tensor, None] = None
if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
if any([x.control_tensor is not None for x in self.file_items]):
# find one to use as a base
base_control_tensor = None
for x in self.file_items:
if x.control_tensor is not None:
base_control_tensor = x.control_tensor
break
control_tensors = []
for x in self.file_items:
if x.control_tensor is None:
control_tensors.append(torch.zeros_like(base_control_tensor))
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
except Exception as e:
print(e)
raise e