mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fixes. Added some functionality to help with private extensions
This commit is contained in:
@@ -84,8 +84,22 @@ class DataLoaderBatchDTO:
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
if self.file_items[0].control_tensor is not None:
|
||||
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
||||
# if self.file_items[0].control_tensor is not None:
|
||||
# if any have a control tensor, we concatenate them
|
||||
if any([x.control_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_control_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.control_tensor is not None:
|
||||
base_control_tensor = x.control_tensor
|
||||
break
|
||||
control_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.control_tensor is None:
|
||||
control_tensors.append(torch.zeros_like(base_control_tensor))
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user