From f5e40dfa62ffeb4a7e3e5b45efd9861b978aec75 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 1 Mar 2025 16:12:52 -0700 Subject: [PATCH] WIP on wan --- .gitmodules | 8 ++++++++ repositories/wan21 | 1 + requirements.txt | 3 +-- toolkit/models/wan21.py | 39 +++++++++++++++++++++++++++++++++++---- 4 files changed, 45 insertions(+), 6 deletions(-) create mode 160000 repositories/wan21 diff --git a/.gitmodules b/.gitmodules index 657cf28b..ea80e2af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,20 @@ [submodule "repositories/sd-scripts"] path = repositories/sd-scripts url = https://github.com/kohya-ss/sd-scripts.git + commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c [submodule "repositories/leco"] path = repositories/leco url = https://github.com/p1atdev/LECO + commit = 9294adf40218e917df4516737afb13f069a6789d [submodule "repositories/batch_annotator"] path = repositories/batch_annotator url = https://github.com/ostris/batch-annotator + commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf [submodule "repositories/ipadapter"] path = repositories/ipadapter url = https://github.com/tencent-ailab/IP-Adapter.git + commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14 +[submodule "repositories/wan21"] + path = repositories/wan21 + url = https://github.com/Wan-Video/Wan2.1.git + commit = a326079926a4a347ecda8863dc40ba2d7680a294 \ No newline at end of file diff --git a/repositories/wan21 b/repositories/wan21 new file mode 160000 index 00000000..a3260799 --- /dev/null +++ b/repositories/wan21 @@ -0,0 +1 @@ +Subproject commit a326079926a4a347ecda8863dc40ba2d7680a294 diff --git a/requirements.txt b/requirements.txt index f521b379..4040e760 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ torch==2.5.1 torchvision==0.20.1 safetensors -# https://github.com/huggingface/diffusers/pull/10921 -git+https://github.com/huggingface/diffusers@refs/pull/10921/head +git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec transformers lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index b52017a2..f30c1cd6 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -3,7 +3,38 @@ import torch from toolkit.config_modules import GenerateImageConfig, ModelConfig from toolkit.models.base_model import BaseModel from toolkit.prompt_utils import PromptEmbeds -from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline +from toolkit.paths import REPOS_ROOT +import sys +import os + +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + + +WAN_ROOT = os.path.join(REPOS_ROOT, "wan21") +sys.path.append(WAN_ROOT) + +if True: + from wan.text2video import WanT2V + from wan.distributed.fsdp import shard_model + from wan.modules.model import WanModel + from wan.modules.t5 import T5EncoderModel + from wan.modules.vae import WanVAE + from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) + from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler class Wan21(BaseModel): def __init__( @@ -21,9 +52,9 @@ class Wan21(BaseModel): # these must be implemented in child classes def load_model(self): - # override this in child classes - raise NotImplementedError( - "load_model must be implemented in child classes") + self.pipeline = Wan21( + + ) def get_generation_pipeline(self): # override this in child classes