mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
WIP on wan
This commit is contained in:
8
.gitmodules
vendored
8
.gitmodules
vendored
@@ -1,12 +1,20 @@
|
|||||||
[submodule "repositories/sd-scripts"]
|
[submodule "repositories/sd-scripts"]
|
||||||
path = repositories/sd-scripts
|
path = repositories/sd-scripts
|
||||||
url = https://github.com/kohya-ss/sd-scripts.git
|
url = https://github.com/kohya-ss/sd-scripts.git
|
||||||
|
commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c
|
||||||
[submodule "repositories/leco"]
|
[submodule "repositories/leco"]
|
||||||
path = repositories/leco
|
path = repositories/leco
|
||||||
url = https://github.com/p1atdev/LECO
|
url = https://github.com/p1atdev/LECO
|
||||||
|
commit = 9294adf40218e917df4516737afb13f069a6789d
|
||||||
[submodule "repositories/batch_annotator"]
|
[submodule "repositories/batch_annotator"]
|
||||||
path = repositories/batch_annotator
|
path = repositories/batch_annotator
|
||||||
url = https://github.com/ostris/batch-annotator
|
url = https://github.com/ostris/batch-annotator
|
||||||
|
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
|
||||||
[submodule "repositories/ipadapter"]
|
[submodule "repositories/ipadapter"]
|
||||||
path = repositories/ipadapter
|
path = repositories/ipadapter
|
||||||
url = https://github.com/tencent-ailab/IP-Adapter.git
|
url = https://github.com/tencent-ailab/IP-Adapter.git
|
||||||
|
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
|
||||||
|
[submodule "repositories/wan21"]
|
||||||
|
path = repositories/wan21
|
||||||
|
url = https://github.com/Wan-Video/Wan2.1.git
|
||||||
|
commit = a326079926a4a347ecda8863dc40ba2d7680a294
|
||||||
1
repositories/wan21
Submodule
1
repositories/wan21
Submodule
Submodule repositories/wan21 added at a326079926
@@ -1,8 +1,7 @@
|
|||||||
torch==2.5.1
|
torch==2.5.1
|
||||||
torchvision==0.20.1
|
torchvision==0.20.1
|
||||||
safetensors
|
safetensors
|
||||||
# https://github.com/huggingface/diffusers/pull/10921
|
git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec
|
||||||
git+https://github.com/huggingface/diffusers@refs/pull/10921/head
|
|
||||||
transformers
|
transformers
|
||||||
lycoris-lora==1.8.3
|
lycoris-lora==1.8.3
|
||||||
flatten_json
|
flatten_json
|
||||||
|
|||||||
@@ -3,7 +3,38 @@ import torch
|
|||||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||||
from toolkit.models.base_model import BaseModel
|
from toolkit.models.base_model import BaseModel
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.distributed as dist
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")
|
||||||
|
sys.path.append(WAN_ROOT)
|
||||||
|
|
||||||
|
if True:
|
||||||
|
from wan.text2video import WanT2V
|
||||||
|
from wan.distributed.fsdp import shard_model
|
||||||
|
from wan.modules.model import WanModel
|
||||||
|
from wan.modules.t5 import T5EncoderModel
|
||||||
|
from wan.modules.vae import WanVAE
|
||||||
|
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas, retrieve_timesteps)
|
||||||
|
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
|
|
||||||
class Wan21(BaseModel):
|
class Wan21(BaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -21,9 +52,9 @@ class Wan21(BaseModel):
|
|||||||
# these must be implemented in child classes
|
# these must be implemented in child classes
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
# override this in child classes
|
self.pipeline = Wan21(
|
||||||
raise NotImplementedError(
|
|
||||||
"load_model must be implemented in child classes")
|
)
|
||||||
|
|
||||||
def get_generation_pipeline(self):
|
def get_generation_pipeline(self):
|
||||||
# override this in child classes
|
# override this in child classes
|
||||||
|
|||||||
Reference in New Issue
Block a user