Finished up first frame for i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-12 17:13:04 -06:00
parent cd37ccfc2e
commit 6fb44db6a0
5 changed files with 206 additions and 87 deletions

View File

@@ -111,6 +111,8 @@ class CustomAdapter(torch.nn.Module):
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
self.cached_control_image_0_1: Optional[torch.Tensor] = None
self.setup_adapter()
@@ -1069,12 +1071,33 @@ class CustomAdapter(torch.nn.Module):
def trigger_pre_te(
self,
tensors_0_1: torch.Tensor,
tensors_0_1: Optional[torch.Tensor]=None,
tensors_preprocessed: Optional[torch.Tensor]=None, # preprocessed by the dataloader
is_training=False,
has_been_preprocessed=False,
batch_tensor: Optional[torch.Tensor]=None,
quad_count=4,
batch_size=1,
) -> PromptEmbeds:
if tensors_0_1 is not None:
# actual 0 - 1 image
self.cached_control_image_0_1 = tensors_0_1
else:
# image has been processed through the dataloader and is prepped for vision encoder
self.cached_control_image_0_1 = None
if batch_tensor is not None and self.cached_control_image_0_1 is None:
# convert it to 0 - 1
to_cache = batch_tensor / 2 + 0.5
# videos come in (bs, num_frames, channels, height, width)
# images come in (bs, channels, height, width)
# if it is a video, just grad first frame
if len(to_cache.shape) == 5:
to_cache = to_cache[:, 0:1, :, :, :]
to_cache = to_cache.squeeze(1)
self.cached_control_image_0_1 = to_cache
if tensors_preprocessed is not None and has_been_preprocessed:
tensors_0_1 = tensors_preprocessed
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
skip_unconditional = self.sd_ref().is_flux