Added initial support for training the i2v adapter (WIP)

Jaret Burkett
2025-04-09 08:06:29 -06:00
parent a8680c75eb
commit 615b0d0e94
7 changed files with 682 additions and 12 deletions


@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.i2v_adapter import I2VAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.single_value_adapter import SingleValueAdapter
@@ -76,6 +77,7 @@ class CustomAdapter(torch.nn.Module):
self.is_active = True
self.flag_word = "fla9wor0"
self.is_unconditional_run = False
self.is_sampling = False
self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
@@ -105,6 +107,7 @@ class CustomAdapter(torch.nn.Module):
self.redux_adapter: ReduxImageEncoder = None
self.control_lora: ControlLoraAdapter = None
self.subpixel_adapter: SubpixelAdapter = None
self.i2v_adapter: I2VAdapter = None
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -255,6 +258,15 @@ class CustomAdapter(torch.nn.Module):
config=self.config,
train_config=self.train_config
)
elif self.adapter_type == 'i2v':
self.i2v_adapter = I2VAdapter(
self,
sd=self.sd_ref(),
config=self.config,
train_config=self.train_config,
image_processor=self.image_processor,
vision_encoder=self.vision_encoder,
)
elif self.adapter_type == 'subpixel':
self.subpixel_adapter = SubpixelAdapter(
self,
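
Note: the I2VAdapter implementation lives in toolkit/models/i2v_adapter.py and is not part of this excerpt. From the constructor call above and the calls made later in this file, the interface it has to expose looks roughly like the following sketch; every name beyond those visible in the diff is an assumption.

    import torch

    class I2VAdapter(torch.nn.Module):
        # Sketch of the interface this commit relies on; the real
        # implementation is not shown in this diff.
        def __init__(self, adapter, sd, config, train_config,
                     image_processor=None, vision_encoder=None):
            super().__init__()
            self.adapter_ref = adapter          # parent CustomAdapter
            self.sd = sd                        # StableDiffusion wrapper
            self.config = config
            self.train_config = train_config
            self.image_processor = image_processor
            self.vision_encoder = vision_encoder

        def condition_noisy_latents(self, latents, batch):
            # called from CustomAdapter.condition_noisy_latents
            raise NotImplementedError

        def load_weights(self, state_dict, strict=True):
            raise NotImplementedError

        def get_state_dict(self):
            raise NotImplementedError

        def get_params(self):
            raise NotImplementedError
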
@@ -512,6 +524,14 @@ class CustomAdapter(torch.nn.Module):
new_dict[k + '.' + k2] = v2
self.control_lora.load_weights(new_dict, strict=strict)
if self.adapter_type == 'i2v':
# state dict is separated, so recombine it
new_dict = {}
for k, v in state_dict.items():
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.i2v_adapter.load_weights(new_dict, strict=strict)
if self.adapter_type == 'subpixel':
# state dict is separated, so recombine it
new_dict = {}
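
Checkpoints for these adapters are saved as one sub-dict per module, so the loop above flattens them back into the dot-joined keys that load_weights expects. A standalone illustration of the same transform, with made-up keys:

    import torch

    # saved layout: one sub-dict per module
    state_dict = {
        'proj': {'weight': torch.zeros(4, 4), 'bias': torch.zeros(4)},
        'norm': {'weight': torch.ones(4)},
    }

    # same recombination loop as in the diff above
    new_dict = {}
    for k, v in state_dict.items():
        for k2, v2 in v.items():
            new_dict[k + '.' + k2] = v2

    print(sorted(new_dict.keys()))
    # ['norm.weight', 'proj.bias', 'proj.weight']
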
@@ -575,6 +595,11 @@ class CustomAdapter(torch.nn.Module):
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'i2v':
d = self.i2v_adapter.get_state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'subpixel':
d = self.subpixel_adapter.get_state_dict()
for k, v in d.items():
@@ -592,7 +617,11 @@ class CustomAdapter(torch.nn.Module):
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
with torch.no_grad():
if self.adapter_type in ['control_lora']:
# todo add i2v start frame conditioning here
if self.adapter_type in ['i2v']:
return self.i2v_adapter.condition_noisy_latents(latents, batch)
elif self.adapter_type in ['control_lora']:
# inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor
# 4th channel is the mask with 1 being keep area and 0 being area to inpaint.
sd: StableDiffusion = self.sd_ref()
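
The i2v branch defers to I2VAdapter.condition_noisy_latents, whose body is not shown in this commit excerpt. For image-to-video training, a common approach is to pin the first temporal frame of the noisy video latents to the clean latent of the start frame, so the model learns to continue from a given image. The sketch below is only a hypothetical illustration of that idea, not the author's confirmed implementation.

    import torch

    def condition_noisy_latents_sketch(noisy_latents: torch.Tensor,
                                       start_frame_latent: torch.Tensor) -> torch.Tensor:
        # noisy_latents: (bs, channels, frames, h, w) video latents
        # start_frame_latent: (bs, channels, 1, h, w) clean first-frame latent
        # Hypothetical: overwrite the first frame with the clean start-frame
        # latent so the model is trained to generate the remaining frames.
        noisy_latents = noisy_latents.clone()
        noisy_latents[:, :, :1] = start_frame_latent
        return noisy_latents
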
@@ -724,7 +753,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
return prompt
elif self.adapter_type == 'text_encoder':
# todo allow for training
@@ -1036,7 +1065,8 @@ class CustomAdapter(torch.nn.Module):
quad_count=4,
batch_size=1,
) -> PromptEmbeds:
if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
# if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
skip_unconditional = self.sd_ref().is_flux
if tensors_0_1 is None:
tensors_0_1 = self.get_empty_clip_image(batch_size)
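
get_empty_clip_image is also not shown in this excerpt; judging from the "dropout with noise" comment in the next hunk, the unconditional image is plausibly random noise in the 0-1 range rather than a black frame. A hypothetical stand-in:

    import torch

    def get_empty_clip_image_sketch(batch_size: int,
                                    shape=(1, 3, 224, 224)) -> torch.Tensor:
        # Assumed behavior of CustomAdapter.get_empty_clip_image: a batch of
        # random 0-1 noise images at the vision encoder's input size.
        return torch.rand(batch_size, *shape[1:])
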
@@ -1091,7 +1121,22 @@ class CustomAdapter(torch.nn.Module):
batch_size = clip_image.shape[0]
if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
if self.config.control_image_dropout > 0 and is_training:
clip_batch = torch.chunk(clip_image, batch_size, dim=0)
unconditional_batch = torch.chunk(self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
clip_image.device, dtype=clip_image.dtype
), batch_size, dim=0)
combine_list = []
for i in range(batch_size):
do_dropout = random.random() < self.config.control_image_dropout
if do_dropout:
# dropout with noise
combine_list.append(unconditional_batch[i])
else:
combine_list.append(clip_batch[i])
clip_image = torch.cat(combine_list, dim=0)
if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v'] and not skip_unconditional:
# add an unconditional so we can save it
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
clip_image.device, dtype=clip_image.dtype
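
The dropout block above works per sample rather than per batch: chunking along dim 0 lets each element be swapped for the unconditional image independently, so a single batch can mix conditioned and unconditioned examples. A compact, functionally equivalent sketch using a boolean mask, assuming a 4D clip_image and an unconditional tensor of matching shape and dtype:

    import torch

    def drop_control_images(clip_image: torch.Tensor,
                            unconditional: torch.Tensor,
                            dropout: float) -> torch.Tensor:
        # Per-sample conditioning dropout: each batch element is independently
        # replaced by its unconditional counterpart with probability `dropout`.
        bs = clip_image.shape[0]
        mask = torch.rand(bs, device=clip_image.device) < dropout
        mask = mask.view(bs, *([1] * (clip_image.dim() - 1)))
        return torch.where(mask, unconditional, clip_image)
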
@@ -1153,7 +1198,8 @@ class CustomAdapter(torch.nn.Module):
img_embeds = img_embeds.detach()
self.ilora_module(img_embeds)
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
# if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v']:
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
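
The pattern in this hunk gates both gradient flow and train mode for the vision encoder: gradients are enabled only while training, and the encoder is put in train mode only when it is actually being fine-tuned. A minimal standalone version of the same pattern:

    import torch

    def encode_image(vision_encoder: torch.nn.Module,
                     pixels: torch.Tensor,
                     is_training: bool,
                     train_image_encoder: bool) -> torch.Tensor:
        # gradients flow only during training; the encoder is only put in
        # train mode when it is itself being fine-tuned, else kept in eval
        with torch.set_grad_enabled(is_training):
            if is_training and train_image_encoder:
                vision_encoder.train()
            else:
                vision_encoder.eval()
            return vision_encoder(pixels)
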
@@ -1248,6 +1294,10 @@ class CustomAdapter(torch.nn.Module):
param_list = self.control_lora.get_params()
for param in param_list:
yield param
elif self.config.type == 'i2v':
param_list = self.i2v_adapter.get_params()
for param in param_list:
yield param
elif self.config.type == 'subpixel':
param_list = self.subpixel_adapter.get_params()
for param in param_list: