diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 10c8613f..8056e87b 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -631,6 +631,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
             with open(path_to_save, 'w') as f:
                 json.dump(json_data, f, indent=4)
+
+        print_acc(f"Saved checkpoint to {file_path}")
 
         # save optimizer
         if self.optimizer is not None:
@@ -639,11 +641,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 file_path = os.path.join(self.save_root, filename)
                 state_dict = self.optimizer.state_dict()
                 torch.save(state_dict, file_path)
+                print_acc(f"Saved optimizer to {file_path}")
             except Exception as e:
                 print_acc(e)
                 print_acc("Could not save optimizer")
 
-        print_acc(f"Saved to {file_path}")
 
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -2095,7 +2097,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # print above the progress bar
                 if self.progress_bar is not None:
                     self.progress_bar.pause()
-                print_acc(f"Saving at step {self.step_num}")
+                print_acc(f"\nSaving at step {self.step_num}")
                 self.save(self.step_num)
                 self.ensure_params_requires_grad()
                 if self.progress_bar is not None:
diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py
index f9615eac..27bc7238 100644
--- a/toolkit/models/i2v_adapter.py
+++ b/toolkit/models/i2v_adapter.py
@@ -10,88 +10,13 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
 from diffusers.models.attention_processor import Attention
 from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
 from toolkit.util.shuffle import shuffle_tensor_along_axis
+import torch.nn.functional as F
 
 if TYPE_CHECKING:
     from toolkit.models.base_model import BaseModel
     from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
     from toolkit.custom_adapter import CustomAdapter
 
-import torch.nn.functional as F
-
-# modified to set the image embedder size
-class WanAttnProcessor2_0:
-    def __init__(self, num_img_tokens:int=257):
-        self.num_img_tokens = num_img_tokens
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        encoder_hidden_states_img = None
-        if attn.add_k_proj is not None:
-            encoder_hidden_states_img = encoder_hidden_states[:, :self.num_img_tokens]
-            encoder_hidden_states = encoder_hidden_states[:, self.num_img_tokens:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        if rotary_emb is not None:
-
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
-
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
-
-        # I2V task
-        hidden_states_img = None
-        if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
-            key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
-            hidden_states_img = hidden_states_img.type_as(query)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.type_as(query)
-
-        if hidden_states_img is not None:
-            hidden_states = hidden_states + hidden_states_img
-
-        hidden_states = attn.to_out[0](hidden_states)
-        hidden_states = attn.to_out[1](hidden_states)
-        return hidden_states
-
 
 class FrameEmbedder(torch.nn.Module):
     def __init__(
@@ -385,6 +310,8 @@ class I2VAdapter(torch.nn.Module):
         vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
     ):
         super().__init__()
+        # avoid circular import
+        from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref = weakref.ref(sd)
         self.model_config: ModelConfig = sd.model_config
diff --git a/toolkit/models/wan21/wan_attn.py b/toolkit/models/wan21/wan_attn.py
new file mode 100644
index 00000000..6cd93e1c
--- /dev/null
+++ b/toolkit/models/wan21/wan_attn.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional
+from diffusers.models.attention_processor import Attention
+
+
+# modified to set the image embedder size
+class WanAttnProcessor2_0:
+    def __init__(self, num_img_tokens: int = 257):
+        self.num_img_tokens = num_img_tokens
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            encoder_hidden_states_img = encoder_hidden_states[:,
+                                                              :self.num_img_tokens]
+            encoder_hidden_states = encoder_hidden_states[:,
+                                                          self.num_img_tokens:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(
+                    hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(
+                2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
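Note (not part of the diff): `WanAttnProcessor2_0` is only relocated from `toolkit/models/i2v_adapter.py` into `toolkit/models/wan21/wan_attn.py`, and `I2VAdapter.__init__` now imports it lazily to avoid a circular import. As a rough, hypothetical sketch of how such a processor could be attached to a Wan transformer's image-conditioned attention layers, assuming only the standard diffusers `Attention.set_processor` API (the `transformer` handle and the `install_wan_attn_processor` helper are illustrative, not code from this change):

```python
from diffusers.models.attention_processor import Attention
from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0


def install_wan_attn_processor(transformer, num_img_tokens: int = 257):
    """Hypothetical helper: swap in the custom processor on every
    cross-attention module that carries an image K projection."""
    for name, module in transformer.named_modules():
        # add_k_proj is only set on the image-conditioned (I2V) attention blocks
        if isinstance(module, Attention) and module.add_k_proj is not None:
            # The processor treats the first `num_img_tokens` tokens of
            # encoder_hidden_states as image tokens and the rest as text.
            module.set_processor(WanAttnProcessor2_0(num_img_tokens=num_img_tokens))
```

The `num_img_tokens=257` default presumably matches the vision tower's output length used by the adapter; the actual wiring is done inside `I2VAdapter` itself, which is why the import there is the only change the diff makes to call sites.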