diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 51b60170..10c8613f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -321,9 +321,17 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.ema is not None: self.ema.eval() + # let adapter know we are sampling + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + self.adapter.is_sampling = True + # send to be generated self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + self.adapter.is_sampling = False + if self.ema is not None: self.ema.train() @@ -579,7 +587,7 @@ class BaseSDTrainProcess(BaseTrainProcess): direct_save = True if self.adapter_config.type == 'redux': direct_save = True - if self.adapter_config.type in ['control_lora', 'subpixel']: + if self.adapter_config.type in ['control_lora', 'subpixel', 'i2v']: direct_save = True save_ip_adapter_from_diffusers( state_dict, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6346d554..de6c8ba1 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -151,14 +151,14 @@ class NetworkConfig: self.lokr_factor = kwargs.get('lokr_factor', -1) -AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora'] +AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state'] class AdapterConfig: def __init__(self, **kwargs): - self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net, i2v self.in_channels: int = kwargs.get('in_channels', 3) self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) @@ -255,6 +255,10 @@ class AdapterConfig: # for subpixel adapter self.subpixel_downscale_factor: int = kwargs.get('subpixel_downscale_factor', 8) + + # for i2v adapter + # append the masked start frame. 
During pretraining we will only do the vision encoder + self.i2v_do_start_frame: bool = kwargs.get('i2v_do_start_frame', False) class EmbeddingConfig: @@ -955,6 +959,8 @@ class GenerateImageConfig: # video if self.num_frames == 1: raise ValueError(f"Expected 1 img but got a list {len(image)}") + if self.num_frames > 1 and self.output_ext not in ['webp']: + self.output_ext = 'webp' if self.output_ext == 'webp': # save as animated webp duration = 1000 // self.fps # Convert fps to milliseconds per frame @@ -1075,6 +1081,8 @@ class GenerateImageConfig: self.extra_values = [float(val) for val in content.split(',')] elif flag == 'frames': self.num_frames = int(content) + elif flag == 'num_frames': + self.num_frames = int(content) elif flag == 'fps': self.fps = int(content) elif flag == 'ctrl_img': diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index d2e4572d..1ac999e5 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.models.clip_fusion import CLIPFusionModule from toolkit.models.clip_pre_processor import CLIPImagePreProcessor from toolkit.models.control_lora_adapter import ControlLoraAdapter +from toolkit.models.i2v_adapter import I2VAdapter from toolkit.models.subpixel_adapter import SubpixelAdapter from toolkit.models.ilora import InstantLoRAModule from toolkit.models.single_value_adapter import SingleValueAdapter @@ -76,6 +77,7 @@ class CustomAdapter(torch.nn.Module): self.is_active = True self.flag_word = "fla9wor0" self.is_unconditional_run = False + self.is_sampling = False self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None @@ -105,6 +107,7 @@ class CustomAdapter(torch.nn.Module): self.redux_adapter: ReduxImageEncoder = None self.control_lora: ControlLoraAdapter = None self.subpixel_adapter: SubpixelAdapter = None + self.i2v_adapter: I2VAdapter = None self.conditional_embeds: Optional[torch.Tensor] = None self.unconditional_embeds: Optional[torch.Tensor] = None @@ -255,6 +258,15 @@ class CustomAdapter(torch.nn.Module): config=self.config, train_config=self.train_config ) + elif self.adapter_type == 'i2v': + self.i2v_adapter = I2VAdapter( + self, + sd=self.sd_ref(), + config=self.config, + train_config=self.train_config, + image_processor=self.image_processor, + vision_encoder=self.vision_encoder, + ) elif self.adapter_type == 'subpixel': self.subpixel_adapter = SubpixelAdapter( self, @@ -512,6 +524,14 @@ class CustomAdapter(torch.nn.Module): new_dict[k + '.' + k2] = v2 self.control_lora.load_weights(new_dict, strict=strict) + if self.adapter_type == 'i2v': + # state dict is seperated. so recombine it + new_dict = {} + for k, v in state_dict.items(): + for k2, v2 in v.items(): + new_dict[k + '.' + k2] = v2 + self.i2v_adapter.load_weights(new_dict, strict=strict) + if self.adapter_type == 'subpixel': # state dict is seperated. 
so recombine it new_dict = {} @@ -575,6 +595,11 @@ class CustomAdapter(torch.nn.Module): for k, v in d.items(): state_dict[k] = v return state_dict + elif self.adapter_type == 'i2v': + d = self.i2v_adapter.get_state_dict() + for k, v in d.items(): + state_dict[k] = v + return state_dict elif self.adapter_type == 'subpixel': d = self.subpixel_adapter.get_state_dict() for k, v in d.items(): @@ -592,7 +617,11 @@ class CustomAdapter(torch.nn.Module): def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO): with torch.no_grad(): - if self.adapter_type in ['control_lora']: + # todo add i2v start frame conditioning here + + if self.adapter_type in ['i2v']: + return self.i2v_adapter.condition_noisy_latents(latents, batch) + elif self.adapter_type in ['control_lora']: # inpainting input is 0-1 (bs, 4, h, w) on batch.inpaint_tensor # 4th channel is the mask with 1 being keep area and 0 being area to inpaint. sd: StableDiffusion = self.sd_ref() @@ -724,7 +753,7 @@ class CustomAdapter(torch.nn.Module): prompt: Union[List[str], str], is_unconditional: bool = False, ): - if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']: + if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel', 'i2v']: return prompt elif self.adapter_type == 'text_encoder': # todo allow for training @@ -1036,7 +1065,8 @@ class CustomAdapter(torch.nn.Module): quad_count=4, batch_size=1, ) -> PromptEmbeds: - if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + # if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + if self.adapter_type in ['ilora', 'vision_direct', 'te_augmenter', 'i2v']: skip_unconditional = self.sd_ref().is_flux if tensors_0_1 is None: tensors_0_1 = self.get_empty_clip_image(batch_size) @@ -1091,7 +1121,22 @@ class CustomAdapter(torch.nn.Module): batch_size = clip_image.shape[0] - if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional: + if self.config.control_image_dropout > 0 and is_training: + clip_batch = torch.chunk(clip_image, batch_size, dim=0) + unconditional_batch = torch.chunk(self.get_empty_clip_image(batch_size, shape=clip_image.shape).to( + clip_image.device, dtype=clip_image.dtype + ), batch_size, dim=0) + combine_list = [] + for i in range(batch_size): + do_dropout = random.random() < self.config.control_image_dropout + if do_dropout: + # dropout with noise + combine_list.append(unconditional_batch[i]) + else: + combine_list.append(clip_batch[i]) + clip_image = torch.cat(combine_list, dim=0) + + if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v'] and not skip_unconditional: # add an unconditional so we can save it unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to( clip_image.device, dtype=clip_image.dtype @@ -1153,7 +1198,8 @@ class CustomAdapter(torch.nn.Module): img_embeds = img_embeds.detach() self.ilora_module(img_embeds) - if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + # if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter': + if self.adapter_type in ['vision_direct', 'te_augmenter', 'i2v']: with torch.set_grad_enabled(is_training): if is_training and self.config.train_image_encoder: self.vision_encoder.train() @@ -1248,6 +1294,10 @@ class CustomAdapter(torch.nn.Module): param_list = 
self.control_lora.get_params() for param in param_list: yield param + elif self.config.type == 'i2v': + param_list = self.i2v_adapter.get_params() + for param in param_list: + yield param elif self.config.type == 'subpixel': param_list = self.subpixel_adapter.get_params() for param in param_list: diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 9244a7c1..00623fa0 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -161,7 +161,8 @@ class DiffusionFeatureExtractor3(nn.Module): vae = AutoencoderTiny.from_pretrained( "madebyollin/taef1", torch_dtype=torch.bfloat16) self.vae = vae - image_encoder_path = "google/siglip-so400m-patch14-384" + # image_encoder_path = "google/siglip-so400m-patch14-384" + image_encoder_path = "google/siglip2-so400m-patch16-512" try: self.image_processor = SiglipImageProcessor.from_pretrained( image_encoder_path) @@ -182,7 +183,11 @@ class DiffusionFeatureExtractor3(nn.Module): dtype = torch.bfloat16 device = self.vae.device # resize to 384x384 - images = F.interpolate(tensors_0_1, size=(384, 384), + if 'height' in self.image_processor.size: + size = self.image_processor.size['height'] + else: + size = self.image_processor.crop_size['height'] + images = F.interpolate(tensors_0_1, size=(size, size), mode='bicubic', align_corners=False) mean = torch.tensor(self.image_processor.image_mean).to( diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py new file mode 100644 index 00000000..09f45995 --- /dev/null +++ b/toolkit/models/i2v_adapter.py @@ -0,0 +1,598 @@ +from functools import partial +import inspect +import weakref +import torch +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.lora_special import LoRASpecialNetwork +from diffusers import WanTransformer3DModel +from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProcessor, CLIPVisionModelWithProjection +from diffusers.models.attention_processor import Attention +from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding + +if TYPE_CHECKING: + from toolkit.models.base_model import BaseModel + from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig + from toolkit.custom_adapter import CustomAdapter + +import torch.nn.functional as F + +# modified to set the image embedder size +class WanAttnProcessor2_0: + def __init__(self, num_img_tokens:int=257): + self.num_img_tokens = num_img_tokens + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + encoder_hidden_states_img = encoder_hidden_states[:, :self.num_img_tokens] + encoder_hidden_states = encoder_hidden_states[:, self.num_img_tokens:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states_img = F.scaled_dot_product_attention( + query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class FrameEmbedder(torch.nn.Module): + def __init__( + self, + adapter: 'I2VAdapter', + orig_layer: torch.nn.Linear, + in_channels=64, + out_channels=3072 + ): + super().__init__() + # only do the weight for the new input. We combine with the original linear layer + init = torch.randn(out_channels, in_channels, + device=orig_layer.weight.device, dtype=orig_layer.weight.dtype) * 0.01 + self.weight = torch.nn.Parameter(init) + + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.orig_layer_ref: weakref.ref = weakref.ref(orig_layer) + + @classmethod + def from_model( + cls, + model: WanTransformer3DModel, + adapter: 'I2VAdapter', + num_control_images=1, + has_inpainting_input=False + ): + # TODO implement this + if model.__class__.__name__ == 'WanTransformer3DModel': + num_adapter_in_channels = model.x_embedder.in_features * num_control_images + + if has_inpainting_input: + # inpainting has the mask before packing latents. 
it is normally 16 ch + 1ch mask + # packed it is 64ch + 4ch mask + # so we need to add 4 to the input channels + num_adapter_in_channels += 4 + + x_embedder: torch.nn.Linear = model.x_embedder + img_embedder = cls( + adapter, + orig_layer=x_embedder, + in_channels=num_adapter_in_channels, + out_channels=x_embedder.out_features, + ) + + # hijack the forward method + x_embedder._orig_ctrl_lora_forward = x_embedder.forward + x_embedder.forward = img_embedder.forward + + # update the config of the transformer + model.config.in_channels = model.config.in_channels * \ + (num_control_images + 1) + model.config["in_channels"] = model.config.in_channels + + return img_embedder + else: + raise ValueError("Model not supported") + + @property + def is_active(self): + return self.adapter_ref().is_active + + def forward(self, x): + if not self.is_active: + # make sure lora is not active + if self.adapter_ref().control_lora is not None: + self.adapter_ref().control_lora.is_active = False + return self.orig_layer_ref()._orig_ctrl_lora_forward(x) + + # make sure lora is active + if self.adapter_ref().control_lora is not None: + self.adapter_ref().control_lora.is_active = True + + orig_device = x.device + orig_dtype = x.dtype + + x = x.to(self.weight.device, dtype=self.weight.dtype) + + orig_weight = self.orig_layer_ref().weight.data.detach() + orig_weight = orig_weight.to( + self.weight.device, dtype=self.weight.dtype) + linear_weight = torch.cat([orig_weight, self.weight], dim=1) + + bias = None + if self.orig_layer_ref().bias is not None: + bias = self.orig_layer_ref().bias.data.detach().to( + self.weight.device, dtype=self.weight.dtype) + + x = torch.nn.functional.linear(x, linear_weight, bias) + + x = x.to(orig_device, dtype=orig_dtype) + return x + + +def deactivatable_forward( + self: 'Attention', + *args, + **kwargs +): + if self._attn_hog_ref() is not None and self._attn_hog_ref().is_active: + self.added_kv_proj_dim = None + self.add_k_proj = self._add_k_proj + self.add_v_proj = self._add_v_proj + self.norm_added_q = self._norm_added_q + self.norm_added_k = self._norm_added_k + else: + self.added_kv_proj_dim = self._attn_hog_ref().added_kv_proj_dim + self.add_k_proj = None + self.add_v_proj = None + self.norm_added_q = None + self.norm_added_k = None + return self._orig_forward(*args, **kwargs) + + +class AttentionHog(torch.nn.Module): + def __init__( + self, + added_kv_proj_dim: int, + adapter: 'I2VAdapter', + attn_layer: Attention, + model: 'WanTransformer3DModel', + ): + super().__init__() + + # To prevent circular import. + from diffusers.models.normalization import FP32LayerNorm, LpNorm, RMSNorm + + self.added_kv_proj_dim = added_kv_proj_dim + self.attn_layer_ref: weakref.ref = weakref.ref(attn_layer) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.model_ref: weakref.ref = weakref.ref(model) + + qk_norm = model.config.qk_norm + + # layers + self.add_k_proj = torch.nn.Linear( + added_kv_proj_dim, + attn_layer.inner_kv_dim, + bias=attn_layer.added_proj_bias + ) + self.add_k_proj.weight.data = self.add_k_proj.weight.data * 0.001 + self.add_v_proj = torch.nn.Linear( + added_kv_proj_dim, + attn_layer.inner_kv_dim, + bias=attn_layer.added_proj_bias + ) + self.add_v_proj.weight.data = self.add_v_proj.weight.data * 0.001 + + # do qk norm. 
It isnt stored in the class, but we can infer it from the attn layer + self.norm_added_q = None + self.norm_added_k = None + + if attn_layer.norm_q is not None: + eps: float = 1e-5 + if qk_norm == "layer_norm": + self.norm_added_q = torch.nn.LayerNorm( + attn_layer.norm_q.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_q.elementwise_affine) + self.norm_added_k = torch.nn.LayerNorm( + attn_layer.norm_k.normalized_shape, eps=eps, elementwise_affine=attn_layer.norm_k.elementwise_affine) + elif qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm( + attn_layer.norm_q.normalized_shape, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm( + attn_layer.norm_k.normalized_shape, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps) + self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # Wanx applies qk norm across all heads + self.norm_added_q = RMSNorm(attn_layer.norm_q.dim, eps=eps) + self.norm_added_k = RMSNorm(attn_layer.norm_k.dim, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) + + # add these to the attn later in a way they can be deactivated + attn_layer._add_k_proj = self.add_k_proj + attn_layer._add_v_proj = self.add_v_proj + attn_layer._norm_added_q = self.norm_added_q + attn_layer._norm_added_k = self.norm_added_k + + # make it deactivateable + attn_layer._attn_hog_ref = weakref.ref(self) + attn_layer._orig_forward = attn_layer.forward + attn_layer.forward = partial(deactivatable_forward, attn_layer) + + def forward(self, *args, **kwargs): + if not self.adapter_ref().is_active: + return self.attn_module(*args, **kwargs) + + # TODO implement this + raise NotImplementedError("Attention hog not implemented") + + def is_active(self): + return self.adapter_ref().is_active + + +def new_wan_forward( + self: WanTransformer3DModel, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + adapter:'I2VAdapter' = self._i2v_adapter_ref() + + if adapter.is_active: + # activate the condition embedder + self.condition_embedder.image_embedder = adapter.image_embedder + + # for wan they are putting the image emcoder embeds on the unconditional + # this needs to be fixed as that wont work. For now, we will will use the embeds we have in order + # we cache an conditional and an unconditional embed. On sampling, it samples conditional first, + # then unconditional. So we just need to keep track of which one we are using. This is a horrible hack + # TODO find a not stupid way to do this. 
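# Why the alternation below works: per denoising step the Wan pipeline runs the
# transformer twice, first with the conditional prompt embeds and then, when
# classifier-free guidance is on, with the negative/unconditional ones (see the
# wan21_i2v.py pipeline further down), so flipping a flag on every sampling call
# tracks which of the two passes we are in. It desyncs if guidance is disabled
# or if the two passes are batched into a single forward.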
+ + if adapter.adapter_ref().is_sampling: + if not hasattr(self, '_do_unconditional'): + # set it to true so we alternate to false immediatly + self._do_unconditional = True + + # alternate it + self._do_unconditional = not self._do_unconditional + if self._do_unconditional: + # slightly reduce strength of conditional for the unconditional + encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds * 0.5 + # encoder_hidden_states_image = adapter.adapter_ref().unconditional_embeds + else: + # use the conditional + encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds + else: + # doing a normal training run, always use conditional embeds + encoder_hidden_states_image = adapter.adapter_ref().conditional_embeds + + else: + # not active deactivate the condition embedder + self.condition_embedder.image_embedder = None + + return self._orig_i2v_adapter_forward( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + ) + + +class I2VAdapter(torch.nn.Module): + def __init__( + self, + adapter: 'CustomAdapter', + sd: 'BaseModel', + config: 'AdapterConfig', + train_config: 'TrainConfig', + image_processor: Union[SiglipImageProcessor, CLIPImageProcessor], + vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection], + ): + super().__init__() + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self.sd_ref = weakref.ref(sd) + self.model_config: ModelConfig = sd.model_config + self.network_config = config.lora_config + self.train_config = train_config + self.config = config + self.device_torch = sd.device_torch + self.control_lora = None + self.image_processor_ref: weakref.ref = weakref.ref(image_processor) + self.vision_encoder_ref: weakref.ref = weakref.ref(vision_encoder) + + ve_img_size = vision_encoder.config.image_size + ve_patch_size = vision_encoder.config.patch_size + num_patches = (ve_img_size // ve_patch_size) ** 2 + num_vision_tokens = num_patches + + # siglip does not have a class token + if not vision_encoder.__class__.__name__.lower().startswith("siglip"): + num_vision_tokens = num_patches + 1 + + model_class = sd.model.__class__.__name__ + + if self.network_config is not None: + + network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs + if hasattr(sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = self.sd.target_lora_modules + + if 'ignore_if_contains' not in network_kwargs: + network_kwargs['ignore_if_contains'] = [] + + network_kwargs['ignore_if_contains'] += [ + 'add_k_proj', + 'add_v_proj', + 'norm_added_q', + 'norm_added_k', + ] + if model_class == 'WanTransformer3DModel': + # always ignore patch_embedding + network_kwargs['ignore_if_contains'].append('patch_embedding') + + self.control_lora = LoRASpecialNetwork( + text_encoder=sd.text_encoder, + unet=sd.unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + 
is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=False, + is_lorm=False, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=sd.is_transformer, + base_model=sd, + **network_kwargs + ) + self.control_lora.force_to(self.device_torch, dtype=torch.float32) + self.control_lora._update_torch_multiplier() + self.control_lora.apply_to( + sd.text_encoder, + sd.unet, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + self.control_lora.can_merge_in = False + self.control_lora.prepare_grad_etc(sd.text_encoder, sd.unet) + if self.train_config.gradient_checkpointing: + self.control_lora.enable_gradient_checkpointing() + + self.frame_embedder: FrameEmbedder = None + if self.config.i2v_do_start_frame: + self.frame_embedder = FrameEmbedder.from_model( + sd.unet, + self, + num_control_images=config.num_control_images, + has_inpainting_input=config.has_inpainting_input + ) + self.frame_embedder.to(self.device_torch) + + # hijack the blocks so we can inject our vision encoder + attn_hog_list = [] + if model_class == 'WanTransformer3DModel': + added_kv_proj_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim + # update the model so it can accept the new input + # wan has i2v with clip-h for i2v, additional k v attn that directly takes + # in the penultimate_hidden_states from the vision encoder + # the kv is on blocks[0].attn2 + sd.model.config.added_kv_proj_dim = added_kv_proj_dim + sd.model.config['added_kv_proj_dim'] = added_kv_proj_dim + + transformer: WanTransformer3DModel = sd.model + for block in transformer.blocks: + block.attn2.added_kv_proj_dim = added_kv_proj_dim + attn_module = AttentionHog( + added_kv_proj_dim, + self, + block.attn2, + transformer + ) + # set the attn function to ours that handles custom number of vision tokens + block.attn2.set_processor(WanAttnProcessor2_0(num_vision_tokens)) + + attn_hog_list.append(attn_module) + else: + raise ValueError(f"Model {model_class} not supported") + + self.attn_hog_list = torch.nn.ModuleList(attn_hog_list) + self.attn_hog_list.to(self.device_torch) + + inner_dim = sd.model.config.num_attention_heads * sd.model.config.attention_head_dim + image_embed_dim = vision_encoder.config.hidden_size + self.image_embedder = WanImageEmbedding(image_embed_dim, inner_dim) + + # override the forward method + if model_class == 'WanTransformer3DModel': + self.sd_ref().model._orig_i2v_adapter_forward = self.sd_ref().model.forward + self.sd_ref().model.forward = partial( + new_wan_forward, + self.sd_ref().model + ) + + # add the wan image embedder + self.sd_ref().model.condition_embedder._image_embedder = self.image_embedder + self.sd_ref().model.condition_embedder._image_embedder.to(self.device_torch) + + self.sd_ref().model._i2v_adapter_ref = weakref.ref(self) + + def get_params(self): + if self.control_lora is not None: + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.control_lora.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr + if 'learning_rate' in sig.parameters: + config['learning_rate'] = self.train_config.lr + params_net = 
self.control_lora.prepare_optimizer_params( + **config + ) + + # we want only tensors here + params = [] + for p in params_net: + if isinstance(p, dict): + params += p["params"] + elif isinstance(p, torch.Tensor): + params.append(p) + elif isinstance(p, list): + params += p + else: + params = [] + + if self.frame_embedder is not None: + # make sure the embedder is float32 + self.frame_embedder.to(torch.float32) + params += list(self.frame_embedder.parameters()) + + # add the attn hogs + for attn_hog in self.attn_hog_list: + params += list(attn_hog.parameters()) + + # add the image embedder + if self.image_embedder is not None: + params += list(self.image_embedder.parameters()) + return params + + def load_weights(self, state_dict, strict=True): + lora_sd = {} + attn_hog_sd = {} + frame_embedder_sd = {} + image_embedder_sd = {} + + for key, value in state_dict.items(): + if "frame_embedder" in key: + new_key = key.replace("frame_embedder.", "") + frame_embedder_sd[new_key] = value + elif "attn_hog" in key: + new_key = key.replace("attn_hog.", "") + attn_hog_sd[new_key] = value + elif "image_embedder" in key: + new_key = key.replace("image_embedder.", "") + image_embedder_sd[new_key] = value + else: + lora_sd[key] = value + + # todo process state dict before loading + if self.control_lora is not None: + self.control_lora.load_weights(lora_sd) + if self.frame_embedder is not None: + self.frame_embedder.load_state_dict( + frame_embedder_sd, strict=False) + self.attn_hog_list.load_state_dict( + attn_hog_sd, strict=False) + self.image_embedder.load_state_dict( + image_embedder_sd, strict=False) + + def get_state_dict(self): + if self.control_lora is not None: + lora_sd = self.control_lora.get_state_dict(dtype=torch.float32) + else: + lora_sd = {} + + if self.frame_embedder is not None: + frame_embedder_sd = self.frame_embedder.state_dict() + for key, value in frame_embedder_sd.items(): + lora_sd[f"frame_embedder.{key}"] = value + + # add the attn hogs + attn_hog_sd = self.attn_hog_list.state_dict() + for key, value in attn_hog_sd.items(): + lora_sd[f"attn_hog.{key}"] = value + + # add the image embedder + image_embedder_sd = self.image_embedder.state_dict() + for key, value in image_embedder_sd.items(): + lora_sd[f"image_embedder.{key}"] = value + + return lora_sd + + def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO): + # todo handle start frame + return latents + + @property + def is_active(self): + return self.adapter_ref().is_active diff --git a/toolkit/models/wan21/wan21_i2v.py b/toolkit/models/wan21/wan21_i2v.py index ace2593e..3a7a0181 100644 --- a/toolkit/models/wan21/wan21_i2v.py +++ b/toolkit/models/wan21/wan21_i2v.py @@ -211,7 +211,7 @@ class AggressiveWanI2VUnloadPipeline(WanImageToVideoPipeline): hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, + encoder_hidden_states_image=image_embeds, # todo I think unconditional should be scaled down version attention_kwargs=attention_kwargs, return_dict=False, )[0] diff --git a/toolkit/sampler.py b/toolkit/sampler.py index bf3fa2e8..47fdc205 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -164,11 +164,12 @@ def get_sampler( config_to_use = copy.deepcopy(flux_config) if arch == "sd": config_to_use = copy.deepcopy(sd_flow_config) - if arch == "flux": + elif arch == "flux": config_to_use = copy.deepcopy(flux_config) elif arch == "lumina2": config_to_use = copy.deepcopy(lumina2_config) else: + print(f"Unknown 
architecture {arch}, using default flux config")
            # use flux by default
            config_to_use = copy.deepcopy(flux_config)
    else:
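A minimal sketch of the input-channel extension used by FrameEmbedder in toolkit/models/i2v_adapter.py: the original x_embedder weight is kept frozen and a small (0.01-scaled) extra weight is concatenated along the input dimension, so the appended control channels start close to a no-op. Shapes and names below are illustrative, not the toolkit's.

import torch
import torch.nn.functional as F

def extended_embed(x: torch.Tensor, orig_linear: torch.nn.Linear, extra_weight: torch.nn.Parameter) -> torch.Tensor:
    # x packs the original latent channels plus the appended control-frame channels;
    # the frozen original weight covers the first in_features columns, the trainable
    # extra_weight covers the rest.
    weight = torch.cat([orig_linear.weight.detach(), extra_weight], dim=1)
    return F.linear(x, weight, orig_linear.bias)

orig = torch.nn.Linear(64, 3072)                            # stand-in for the base x_embedder
extra_w = torch.nn.Parameter(torch.randn(3072, 64) * 0.01)  # near-zero init, as in FrameEmbedder
x = torch.randn(2, 1024, 128)                               # (batch, tokens, 64 latent + 64 control)
out = extended_embed(x, orig, extra_w)                      # -> (2, 1024, 3072)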
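The num_img_tokens given to WanAttnProcessor2_0 follows from the vision encoder config the same way I2VAdapter computes num_vision_tokens; a quick check of the arithmetic, assuming the encoders named here:

def num_vision_tokens(image_size: int, patch_size: int, has_class_token: bool) -> int:
    # mirrors the (image_size // patch_size) ** 2 computation in I2VAdapter
    tokens = (image_size // patch_size) ** 2
    return tokens + 1 if has_class_token else tokens

# CLIP ViT-H/14 at 224px: 16 * 16 patches + class token = 257 (the processor's default)
assert num_vision_tokens(224, 14, has_class_token=True) == 257
# SigLIP has no class token; e.g. siglip2-so400m-patch16-512 gives 32 * 32 = 1024 tokens
assert num_vision_tokens(512, 16, has_class_token=False) == 1024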
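The video branch added to GenerateImageConfig forces output_ext to webp and converts fps to a per-frame duration with 1000 // fps; an animated WebP save along those lines might look like this sketch (PIL-based, not the toolkit's exact save code):

from typing import List
from PIL import Image

def save_animated_webp(frames: List[Image.Image], path: str, fps: int) -> None:
    # duration is the per-frame display time in milliseconds, as in the config code above
    duration = 1000 // fps
    frames[0].save(
        path,                        # should end in .webp so PIL picks the WebP writer
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,                      # loop forever
    )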
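With the new num_frames alias and the ctrl_img/fps flags parsed above, a video sample prompt could look like the line below; this assumes the usual "--flag value" sample-prompt syntax, and the path and values are placeholders.

# Placeholder values; assumes the standard "--flag value" sample-prompt syntax.
sample_prompt = (
    "woman dancing in the rain"
    " --ctrl_img /path/to/start_frame.png"
    " --num_frames 41 --fps 16"
)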