WIP on wan

This commit is contained in:
Jaret Burkett
2025-03-01 16:12:52 -07:00
parent acc79956aa
commit f5e40dfa62
4 changed files with 45 additions and 6 deletions

8
.gitmodules vendored
View File

@@ -1,12 +1,20 @@
[submodule "repositories/sd-scripts"] [submodule "repositories/sd-scripts"]
path = repositories/sd-scripts path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git url = https://github.com/kohya-ss/sd-scripts.git
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
[submodule "repositories/leco"] [submodule "repositories/leco"]
path = repositories/leco path = repositories/leco
url = https://github.com/p1atdev/LECO url = https://github.com/p1atdev/LECO
commit = 9294adf40218e917df4516737afb13f069a6789d
[submodule "repositories/batch_annotator"] [submodule "repositories/batch_annotator"]
path = repositories/batch_annotator path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator url = https://github.com/ostris/batch-annotator
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
[submodule "repositories/ipadapter"] [submodule "repositories/ipadapter"]
path = repositories/ipadapter path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git url = https://github.com/tencent-ailab/IP-Adapter.git
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
[submodule "repositories/wan21"]
path = repositories/wan21
url = https://github.com/Wan-Video/Wan2.1.git
commit = a326079926a4a347ecda8863dc40ba2d7680a294

1
repositories/wan21 Submodule

Submodule repositories/wan21 added at a326079926

View File

@@ -1,8 +1,7 @@
torch==2.5.1 torch==2.5.1
torchvision==0.20.1 torchvision==0.20.1
safetensors safetensors
# https://github.com/huggingface/diffusers/pull/10921 git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
git+https://github.com/huggingface/diffusers@refs/pull/10921/head
transformers transformers
lycoris-lora==1.8.3 lycoris-lora==1.8.3
flatten_json flatten_json

View File

@@ -3,7 +3,38 @@ import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds from toolkit.prompt_utils import PromptEmbeds
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline from toolkit.paths import REPOS_ROOT
import sys
import os
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")
sys.path.append(WAN_ROOT)
if True:
from wan.text2video import WanT2V
from wan.distributed.fsdp import shard_model
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class Wan21(BaseModel): class Wan21(BaseModel):
def __init__( def __init__(
@@ -21,9 +52,9 @@ class Wan21(BaseModel):
# these must be implemented in child classes # these must be implemented in child classes
def load_model(self): def load_model(self):
# override this in child classes self.pipeline = Wan21(
raise NotImplementedError(
"load_model must be implemented in child classes") )
def get_generation_pipeline(self): def get_generation_pipeline(self):
# override this in child classes # override this in child classes