This commit is contained in:
Jaret Burkett
2025-04-18 11:44:49 -06:00
parent 1628884254
commit d455e76c4f
3 changed files with 91 additions and 78 deletions

View File

@@ -631,6 +631,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
with open(path_to_save, 'w') as f:
json.dump(json_data, f, indent=4)
print_acc(f"Saved checkpoint to {file_path}")
# save optimizer
if self.optimizer is not None:
@@ -639,11 +641,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
state_dict = self.optimizer.state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
except Exception as e:
print_acc(e)
print_acc("Could not save optimizer")
print_acc(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
@@ -2095,7 +2097,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# print above the progress bar
if self.progress_bar is not None:
self.progress_bar.pause()
print_acc(f"Saving at step {self.step_num}")
print_acc(f"\nSaving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
if self.progress_bar is not None:

View File

@@ -10,88 +10,13 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
import torch.nn.functional as F
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
from toolkit.custom_adapter import CustomAdapter
import torch.nn.functional as F
# modified to set the image embedder size
class WanAttnProcessor2_0:
def __init__(self, num_img_tokens:int=257):
self.num_img_tokens = num_img_tokens
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:, :self.num_img_tokens]
encoder_hidden_states = encoder_hidden_states[:, self.num_img_tokens:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class FrameEmbedder(torch.nn.Module):
def __init__(
@@ -385,6 +310,8 @@ class I2VAdapter(torch.nn.Module):
vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
):
super().__init__()
# avoid circular import
from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref = weakref.ref(sd)
self.model_config: ModelConfig = sd.model_config

View File

@@ -0,0 +1,84 @@
import torch
import torch.nn.functional as F
from typing import Optional
from diffusers.models.attention_processor import Attention
# modified to set the image embedder size
class WanAttnProcessor2_0:
def __init__(self, num_img_tokens: int = 257):
self.num_img_tokens = num_img_tokens
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:,
:self.num_img_tokens]
encoder_hidden_states = encoder_hidden_states[:,
self.num_img_tokens:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(
hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(
2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states