Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Cleanup
@@ -631,6 +631,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
        with open(path_to_save, 'w') as f:
            json.dump(json_data, f, indent=4)

+       print_acc(f"Saved checkpoint to {file_path}")
+
        # save optimizer
        if self.optimizer is not None:
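For reference (not part of this commit), the learnable-SNR values written above with json.dump can be read back with json.load; the path below is illustrative.

import json
import os

# Round-trip sketch: load the JSON checkpoint written by the code above.
# The save_root location is a placeholder, not the repository's real path.
path_to_load = os.path.join("/path/to/save_root", "learnable_snr.json")
if os.path.exists(path_to_load):
    with open(path_to_load, "r") as f:
        json_data = json.load(f)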
@@ -639,11 +641,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                file_path = os.path.join(self.save_root, filename)
                state_dict = self.optimizer.state_dict()
                torch.save(state_dict, file_path)
                print_acc(f"Saved optimizer to {file_path}")
            except Exception as e:
                print_acc(e)
                print_acc("Could not save optimizer")

        print_acc(f"Saved to {file_path}")
        self.clean_up_saves()
        self.post_save_hook(file_path)
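Similarly, the optimizer state_dict saved above with torch.save can be restored when resuming. This is a generic PyTorch sketch, not code from this repository; the model, optimizer type, and filename are placeholders.

import os
import torch

# Resume sketch: rebuild an optimizer and load the state_dict written by
# torch.save(self.optimizer.state_dict(), file_path). All names are placeholders.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
file_path = os.path.join("/path/to/save_root", "optimizer.pt")
if os.path.exists(file_path):
    optimizer.load_state_dict(torch.load(file_path, map_location="cpu"))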
@@ -2095,7 +2097,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # print above the progress bar
        if self.progress_bar is not None:
            self.progress_bar.pause()
-       print_acc(f"Saving at step {self.step_num}")
+       print_acc(f"\nSaving at step {self.step_num}")
        self.save(self.step_num)
        self.ensure_params_requires_grad()
        if self.progress_bar is not None:
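The one-line change above prefixes the save message with "\n" so it starts on its own line rather than being appended to the active progress-bar line. A generic tqdm illustration of the same idea (this is not the toolkit's progress-bar wrapper):

from tqdm import tqdm
import time

# Printing while a bar is active: tqdm.write() (or a "\n"-prefixed print)
# keeps the message from colliding with the progress-bar line.
bar = tqdm(range(3))
for step in bar:
    time.sleep(0.1)
    tqdm.write(f"Saving at step {step}")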
@@ -10,88 +10,13 @@ from transformers import SiglipImageProcessor, SiglipVisionModel, CLIPImageProce
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_wan import WanImageEmbedding, WanTimeTextImageEmbedding
from toolkit.util.shuffle import shuffle_tensor_along_axis
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter

import torch.nn.functional as F


# modified to set the image embedder size
class WanAttnProcessor2_0:
    def __init__(self, num_img_tokens:int=257):
        self.num_img_tokens = num_img_tokens
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            encoder_hidden_states_img = encoder_hidden_states[:, :self.num_img_tokens]
            encoder_hidden_states = encoder_hidden_states[:, self.num_img_tokens:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class FrameEmbedder(torch.nn.Module):
    def __init__(
@@ -385,6 +310,8 @@ class I2VAdapter(torch.nn.Module):
        vision_encoder: Union[SiglipVisionModel, CLIPVisionModelWithProjection],
    ):
        super().__init__()
+       # avoid circular import
+       from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
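The adapter now pulls WanAttnProcessor2_0 from the new toolkit.models.wan21.wan_attn module, deferring the import to __init__ per the "# avoid circular import" comment. A hypothetical sketch of how such a processor is typically installed on a diffusers Wan transformer (the transformer object, hook availability, and token count below are assumptions, not taken from this commit):

from toolkit.models.wan21.wan_attn import WanAttnProcessor2_0

def install_i2v_attn_processor(transformer, num_img_tokens: int = 257):
    # Assumes the transformer exposes the usual diffusers set_attn_processor hook;
    # passing a single instance applies it to every attention module.
    transformer.set_attn_processor(WanAttnProcessor2_0(num_img_tokens=num_img_tokens))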
toolkit/models/wan21/wan_attn.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import torch
import torch.nn.functional as F
from typing import Optional
from diffusers.models.attention_processor import Attention


# modified to set the image embedder size
class WanAttnProcessor2_0:
    def __init__(self, num_img_tokens: int = 257):
        self.num_img_tokens = num_img_tokens
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            encoder_hidden_states_img = encoder_hidden_states[:,
                                                              :self.num_img_tokens]
            encoder_hidden_states = encoder_hidden_states[:,
                                                          self.num_img_tokens:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(
                    hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(
                2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
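For orientation (not part of the commit), the processor assumes that image-conditioning tokens are concatenated ahead of the text tokens in encoder_hidden_states, so the first num_img_tokens entries can be sliced off for the image cross-attention branch. A toy illustration with made-up shapes:

import torch

# Made-up shapes: image tokens sit in front of text tokens, so the processor
# can recover them with a simple slice on the token dimension.
batch, num_img_tokens, num_text_tokens, dim = 2, 257, 512, 1536
image_tokens = torch.randn(batch, num_img_tokens, dim)
text_tokens = torch.randn(batch, num_text_tokens, dim)
encoder_hidden_states = torch.cat([image_tokens, text_tokens], dim=1)

img_part = encoder_hidden_states[:, :num_img_tokens]   # image branch (add_k_proj / add_v_proj)
txt_part = encoder_hidden_states[:, num_img_tokens:]   # regular text cross-attention
assert img_part.shape == image_tokens.shape and txt_part.shape == text_tokens.shape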