mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Finished up first frame for i2v adapter
This commit is contained in:
@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
self.conditional_embeds: Optional[torch.Tensor] = None
|
||||
self.unconditional_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
self.cached_control_image_0_1: Optional[torch.Tensor] = None
|
||||
|
||||
self.setup_adapter()
|
||||
|
||||
@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
def trigger_pre_te(
|
||||
self,
|
||||
tensors_0_1: torch.Tensor,
|
||||
tensors_0_1: Optional[torch.Tensor]=None,
|
||||
tensors_preprocessed: Optional[torch.Tensor]=None, # preprocessed by the dataloader
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
batch_tensor: Optional[torch.Tensor]=None,
|
||||
quad_count=4,
|
||||
batch_size=1,
|
||||
) -> PromptEmbeds:
|
||||
if tensors_0_1 is not None:
|
||||
# actual 0 - 1 image
|
||||
self.cached_control_image_0_1 = tensors_0_1
|
||||
else:
|
||||
# image has been processed through the dataloader and is prepped for vision encoder
|
||||
self.cached_control_image_0_1 = None
|
||||
if batch_tensor is not None and self.cached_control_image_0_1 is None:
|
||||
# convert it to 0 - 1
|
||||
to_cache = batch_tensor / 2 + 0.5
|
||||
# videos come in (bs, num_frames, channels, height, width)
|
||||
# images come in (bs, channels, height, width)
|
||||
# if it is a video, just grad first frame
|
||||
if len(to_cache.shape) == 5:
|
||||
to_cache = to_cache[:, 0:1, :, :, :]
|
||||
to_cache = to_cache.squeeze(1)
|
||||
self.cached_control_image_0_1 = to_cache
|
||||
|
||||
if tensors_preprocessed is not None and has_been_preprocessed:
|
||||
tensors_0_1 = tensors_preprocessed
|
||||
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
|
||||
skip_unconditional = self.sd_ref().is_flux
|
||||
|
||||
Reference in New Issue
Block a user