mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP on wan
This commit is contained in:
8
.gitmodules
vendored
8
.gitmodules
vendored
@@ -1,12 +1,20 @@
|
||||
[submodule "repositories/sd-scripts"]
|
||||
path = repositories/sd-scripts
|
||||
url = https://github.com/kohya-ss/sd-scripts.git
|
||||
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
|
||||
[submodule "repositories/leco"]
|
||||
path = repositories/leco
|
||||
url = https://github.com/p1atdev/LECO
|
||||
commit = 9294adf40218e917df4516737afb13f069a6789d
|
||||
[submodule "repositories/batch_annotator"]
|
||||
path = repositories/batch_annotator
|
||||
url = https://github.com/ostris/batch-annotator
|
||||
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
|
||||
[submodule "repositories/ipadapter"]
|
||||
path = repositories/ipadapter
|
||||
url = https://github.com/tencent-ailab/IP-Adapter.git
|
||||
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
|
||||
[submodule "repositories/wan21"]
|
||||
path = repositories/wan21
|
||||
url = https://github.com/Wan-Video/Wan2.1.git
|
||||
commit = a326079926a4a347ecda8863dc40ba2d7680a294
|
||||
1
repositories/wan21
Submodule
1
repositories/wan21
Submodule
Submodule repositories/wan21 added at a326079926
@@ -1,8 +1,7 @@
|
||||
torch==2.5.1
|
||||
torchvision==0.20.1
|
||||
safetensors
|
||||
# https://github.com/huggingface/diffusers/pull/10921
|
||||
git+https://github.com/huggingface/diffusers@refs/pull/10921/head
|
||||
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
|
||||
transformers
|
||||
lycoris-lora==1.8.3
|
||||
flatten_json
|
||||
|
||||
@@ -3,7 +3,38 @@ import torch
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
import os
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")
|
||||
sys.path.append(WAN_ROOT)
|
||||
|
||||
if True:
|
||||
from wan.text2video import WanT2V
|
||||
from wan.distributed.fsdp import shard_model
|
||||
from wan.modules.model import WanModel
|
||||
from wan.modules.t5 import T5EncoderModel
|
||||
from wan.modules.vae import WanVAE
|
||||
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps)
|
||||
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
class Wan21(BaseModel):
|
||||
def __init__(
|
||||
@@ -21,9 +52,9 @@ class Wan21(BaseModel):
|
||||
# these must be implemented in child classes
|
||||
|
||||
def load_model(self):
|
||||
# override this in child classes
|
||||
raise NotImplementedError(
|
||||
"load_model must be implemented in child classes")
|
||||
self.pipeline = Wan21(
|
||||
|
||||
)
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
# override this in child classes
|
||||
|
||||
Reference in New Issue
Block a user