mirror of https://github.com/ostris/ai-toolkit.git
Added initial support for training the i2v (image-to-video) adapter (WIP)
@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.control_lora_adapter import ControlLoraAdapter
+from toolkit.models.i2v_adapter import I2VAdapter
 from toolkit.models.subpixel_adapter import SubpixelAdapter
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.single_value_adapter import SingleValueAdapter
@@ -76,6 +77,7 @@ class CustomAdapter(torch.nn.Module):
         self.is_active = True
         self.flag_word = "fla9wor0"
         self.is_unconditional_run = False
+        self.is_sampling = False
 
         self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
 
@@ -105,6 +107,7 @@ class CustomAdapter(torch.nn.Module):
         self.redux_adapter: ReduxImageEncoder = None
         self.control_lora: ControlLoraAdapter = None
         self.subpixel_adapter: SubpixelAdapter = None
+        self.i2v_adapter: I2VAdapter = None
 
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -255,6 +258,15 @@ class CustomAdapter(torch.nn.Module):
                 config=self.config,
                 train_config=self.train_config
             )
+        elif self.adapter_type == 'i2v':
+            self.i2v_adapter = I2VAdapter(
+                self,
+                sd=self.sd_ref(),
+                config=self.config,
+                train_config=self.train_config,
+                image_processor=self.image_processor,
+                vision_encoder=self.vision_encoder,
+            )
         elif self.adapter_type == 'subpixel':
             self.subpixel_adapter = SubpixelAdapter(
                 self,
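Note: the new branch hands the I2VAdapter everything it needs to encode a start frame — a back-reference to the parent CustomAdapter, the StableDiffusion wrapper, both configs, and the shared CLIP image processor and vision encoder. A minimal sketch of what such a module could look like; the real toolkit.models.i2v_adapter.I2VAdapter may differ, and all names below are illustrative:

    import weakref
    import torch

    class I2VAdapterSketch(torch.nn.Module):
        # hypothetical skeleton, not the actual I2VAdapter implementation
        def __init__(self, adapter, sd, config, train_config, image_processor, vision_encoder):
            super().__init__()
            # weak references avoid registering the parent adapter and model
            # as submodules, mirroring the sd_ref() pattern seen in the diff
            self.adapter_ref = weakref.ref(adapter)
            self.sd_ref = weakref.ref(sd)
            self.config = config
            self.train_config = train_config
            self.image_processor = image_processor  # shared CLIP preprocessing
            self.vision_encoder = vision_encoder    # shared CLIP vision tower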
@@ -512,6 +524,14 @@ class CustomAdapter(torch.nn.Module):
                     new_dict[k + '.' + k2] = v2
             self.control_lora.load_weights(new_dict, strict=strict)
 
+        if self.adapter_type == 'i2v':
+            # state dict is separated, so recombine it
+            new_dict = {}
+            for k, v in state_dict.items():
+                for k2, v2 in v.items():
+                    new_dict[k + '.' + k2] = v2
+            self.i2v_adapter.load_weights(new_dict, strict=strict)
+
         if self.adapter_type == 'subpixel':
             # state dict is separated, so recombine it
             new_dict = {}
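The load path assumes the checkpoint was saved as a nested dict (one sub-dict per module) and flattens it back to the dotted keys torch modules expect before calling load_weights. A self-contained illustration of that recombination (module names here are made up):

    import torch

    # nested layout as saved on disk: {module_name: {param_name: tensor}}
    saved = {
        'proj_in': {'weight': torch.zeros(4, 4), 'bias': torch.zeros(4)},
        'proj_out': {'weight': torch.zeros(4, 4)},
    }

    # recombine to flat dotted keys, exactly as the loop in load_weights does
    flat = {}
    for module_name, params in saved.items():
        for param_name, tensor in params.items():
            flat[module_name + '.' + param_name] = tensor

    assert set(flat) == {'proj_in.weight', 'proj_in.bias', 'proj_out.weight'}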
@@ -575,6 +595,11 @@ class CustomAdapter(torch.nn.Module):
             for k, v in d.items():
                 state_dict[k] = v
             return state_dict
+        elif self.adapter_type == 'i2v':
+            d = self.i2v_adapter.get_state_dict()
+            for k, v in d.items():
+                state_dict[k] = v
+            return state_dict
         elif self.adapter_type == 'subpixel':
             d = self.subpixel_adapter.get_state_dict()
             for k, v in d.items():
@@ -592,7 +617,11 @@ class CustomAdapter(torch.nn.Module):
 
     def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
         with torch.no_grad():
-            if self.adapter_type in ['control_lora']:
+            # todo add i2v start frame conditioning here
+
+            if self.adapter_type in ['i2v']:
+                return self.i2v_adapter.condition_noisy_latents(latents, batch)
+            elif self.adapter_type in ['control_lora']:
                 # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor
                 # 4th channel is the mask with 1 being keep area and 0 being area to inpaint.
                 sd: StableDiffusion = self.sd_ref()
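The i2v branch delegates latent conditioning to the adapter, and the commit leaves the implementation marked WIP. A common image-to-video conditioning scheme is to overwrite the noisy latents of the first frame with the clean latents of the start image, so the model learns to hold that frame fixed while denoising the rest. A hedged sketch of that idea — an assumption about the technique, not necessarily what I2VAdapter ends up doing:

    import torch

    def condition_noisy_latents_sketch(
        latents: torch.Tensor,             # (bs, c, frames, h, w) noisy video latents
        start_frame_latents: torch.Tensor,  # (bs, c, 1, h, w) clean VAE latents of the start image
    ) -> torch.Tensor:
        with torch.no_grad():
            latents = latents.clone()
            # pin the first frame to the clean start-image latents
            latents[:, :, 0:1] = start_frame_latents
        return latents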
@@ -724,7 +753,7 @@ class CustomAdapter(torch.nn.Module):
         prompt: Union[List[str], str],
         is_unconditional: bool = False,
     ):
-        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
+        if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']:
             return prompt
         elif self.adapter_type == 'text_encoder':
             # todo allow for training
@@ -1036,7 +1065,8 @@ class CustomAdapter(torch.nn.Module):
         quad_count=4,
         batch_size=1,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+        # if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+        if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']:
             skip_unconditional = self.sd_ref().is_flux
             if tensors_0_1 is None:
                 tensors_0_1 = self.get_empty_clip_image(batch_size)
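Here get_empty_clip_image supplies the unconditional image input when no control image exists, or when one is dropped out below. The "# dropout with noise" comment in the next hunk suggests the "empty" image is noise rather than zeros; a sketch under that assumption (the real helper in ai-toolkit may differ):

    import torch

    def get_empty_clip_image_sketch(batch_size: int, shape=None) -> torch.Tensor:
        # assumption: the empty image is random noise in the 0-1 range,
        # shaped to match the CLIP preprocessor's expected input
        if shape is None:
            shape = (batch_size, 3, 224, 224)  # 224x224 is typical for CLIP vision towers
        return torch.rand(shape)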
@@ -1091,7 +1121,22 @@ class CustomAdapter(torch.nn.Module):
 
             batch_size = clip_image.shape[0]
-            if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
+            if self.config.control_image_dropout > 0 and is_training:
+                clip_batch = torch.chunk(clip_image, batch_size, dim=0)
+                unconditional_batch = torch.chunk(self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
+                    clip_image.device, dtype=clip_image.dtype
+                ), batch_size, dim=0)
+                combine_list = []
+                for i in range(batch_size):
+                    do_dropout = random.random() < self.config.control_image_dropout
+                    if do_dropout:
+                        # dropout with noise
+                        combine_list.append(unconditional_batch[i])
+                    else:
+                        combine_list.append(clip_batch[i])
+                clip_image = torch.cat(combine_list, dim=0)
+
+            if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v'] and not skip_unconditional:
                 # add an unconditional so we can save it
                 unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                     clip_image.device, dtype=clip_image.dtype
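The new dropout block decides per sample rather than per batch: torch.chunk splits the batch into single-sample tensors, each one is independently swapped for the unconditional image with probability control_image_dropout, and torch.cat reassembles the batch. A standalone demonstration of the pattern (function name and shapes are illustrative):

    import random
    import torch

    def per_sample_dropout(batch: torch.Tensor, fallback: torch.Tensor, p: float) -> torch.Tensor:
        # batch and fallback are (bs, ...) tensors of identical shape
        bs = batch.shape[0]
        kept = []
        for sample, blank in zip(torch.chunk(batch, bs, dim=0), torch.chunk(fallback, bs, dim=0)):
            # each sample is dropped independently with probability p
            kept.append(blank if random.random() < p else sample)
        return torch.cat(kept, dim=0)

    images = torch.rand(4, 3, 224, 224)
    noise = torch.rand(4, 3, 224, 224)
    out = per_sample_dropout(images, noise, p=0.1)  # ~10% of samples replaced
    assert out.shape == images.shape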
@@ -1153,7 +1198,8 @@ class CustomAdapter(torch.nn.Module):
                 img_embeds = img_embeds.detach()
 
             self.ilora_module(img_embeds)
-        if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+        # if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+        if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v']:
             with torch.set_grad_enabled(is_training):
                 if is_training and self.config.train_image_encoder:
                     self.vision_encoder.train()
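torch.set_grad_enabled(is_training) lets the same code path serve both training and inference, while train_image_encoder decides whether the vision tower itself accumulates gradients. The pattern in isolation, as a generic sketch (helper name and flags are illustrative):

    import torch

    def encode(vision_encoder: torch.nn.Module, pixels: torch.Tensor,
               is_training: bool, train_encoder: bool) -> torch.Tensor:
        with torch.set_grad_enabled(is_training):
            if is_training and train_encoder:
                vision_encoder.train()       # encoder is part of the trained graph
                return vision_encoder(pixels)
            vision_encoder.eval()
            with torch.no_grad():            # frozen encoder: no graph needed
                return vision_encoder(pixels)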
@@ -1248,6 +1294,10 @@ class CustomAdapter(torch.nn.Module):
             param_list = self.control_lora.get_params()
             for param in param_list:
                 yield param
+        elif self.config.type == 'i2v':
+            param_list = self.i2v_adapter.get_params()
+            for param in param_list:
+                yield param
         elif self.config.type == 'subpixel':
             param_list = self.subpixel_adapter.get_params()
             for param in param_list:
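get_params is a generator: each branch yields the trainable parameters of whichever sub-adapter is active, so the trainer can hand them straight to an optimizer. Usage in miniature (TinyAdapter is a made-up stand-in for the real adapters):

    import torch

    class TinyAdapter(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def get_params(self):
            # yield parameters lazily, like CustomAdapter.get_params above
            for p in self.proj.parameters():
                yield p

    adapter = TinyAdapter()
    # optimizers accept any iterable of parameters, including a generator
    optimizer = torch.optim.AdamW(adapter.get_params(), lr=1e-4)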